From 954d9f3c417a64fb33def595fe305397fcc1a4e7 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 14 Nov 2023 15:20:13 +0800 Subject: [PATCH 01/14] update flash-context-attention --- colossalai/kernel/triton/context_attention.py | 203 +++++++++++++++++- requirements/requirements.txt | 1 + .../triton/test_llama_context_attention.py | 4 +- 3 files changed, 204 insertions(+), 4 deletions(-) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 5ce6f2c21385..023185d429b2 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -136,6 +136,178 @@ def _context_flash_attention_kernel( out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return + + @triton.jit + def _bloom_fwd_kernel_2_1_0( + Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, + Out, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + m_i = tl.zeros([BLOCK_M], 
dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + alibi_m = tl.load(Alibi + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @triton.jit + def _llama_fwd_kernel_2_1_0( + Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 + Out, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, 
stride_vd, + stride_obs, stride_oh, stride_od, + kv_group_num, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + 
m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return @torch.no_grad() def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): @@ -189,7 +361,21 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al num_stages=1, ) else: - raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + + kv_group_num = q.shape[1] // k.shape[1] + _bloom_fwd_kernel_2_1_0[grid]( + q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, + o, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) return @@ -244,6 +430,19 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_stages=1, ) else: - raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + kv_group_num = q.shape[1] // k.shape[1] + _llama_fwd_kernel_2_1_0[grid]( + q, k, v, sm_scale, b_start_loc, b_seq_len, + 
o, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + kv_group_num=kv_group_num, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1,) return \ No newline at end of file diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 095617d76355..f3d58ea72c60 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,3 +16,4 @@ ray sentencepiece google protobuf +pytest diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index be6de6db2471..4250bed3dad5 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -41,9 +41,9 @@ def test_llama_context_attention(): llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - + assert torch.allclose( - torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2 ), "outputs from triton and torch are not matched" From de0fe78df64778efb33a0faaca613b338d0794cb Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 14 Nov 2023 18:43:28 +0800 Subject: [PATCH 02/14] adding kernels --- colossalai/kernel/triton/context_attention.py | 554 ++++++++---------- 1 file changed, 249 insertions(+), 305 deletions(-) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 023185d429b2..1ad7a80eb5e7 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -15,299 +15,223 @@ this function is modified from 
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 """ - - @triton.jit - def _context_flash_attention_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, - alibi_ptr, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - batch_id = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info - cur_batch_seq_len = tl.load(B_Seqlen + batch_id) - cur_batch_start_index = tl.load(B_Start_Loc + batch_id) - block_start_loc = BLOCK_M * start_m - - load_p_ptrs = ( - Q - + (cur_batch_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - if alibi_ptr is not None: - alibi_m = tl.load(alibi_ptr + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load( - k_ptrs + 
(cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, + if triton.__version__ < "2.1.0": + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) if alibi_ptr is not None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij 
= tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + p = p.to(v.dtype) + acc += 
tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - off_o = ( - (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @triton.jit - def _bloom_fwd_kernel_2_1_0( - Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, - Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd - off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - alibi_m = tl.load(Alibi + cur_head) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, 
block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m - - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @triton.jit - def _llama_fwd_kernel_2_1_0( - Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - kv_group_num, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // 
kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / 
l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + else: + @triton.jit + def _context_flash_attention_kernel_2( + Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, + Out, + kv_group_num, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + if kv_group_num is not None: + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + if kv_group_num is None or kv_group_num == 1: + off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + else: + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * 
stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if Alibi is not None: + alibi_m = tl.load(Alibi + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if Alibi is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < 
cur_batch_seq_len) + return @torch.no_grad() def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): @@ -324,10 +248,9 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) num_warps = 4 if Lk <= 64 else 8 - - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) if triton.__version__ < "2.1.0": + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) _context_flash_attention_kernel[grid]( q, k, @@ -361,15 +284,22 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al num_stages=1, ) else: - - kv_group_num = q.shape[1] // k.shape[1] - _bloom_fwd_kernel_2_1_0[grid]( + _context_flash_attention_kernel_2[grid]( q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), + None, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -406,7 +336,7 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): b_start_loc, b_seq_len, tmp, - None, + None, o, q.stride(0), q.stride(1), @@ -431,14 +361,28 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): ) else: kv_group_num = q.shape[1] // k.shape[1] - _llama_fwd_kernel_2_1_0[grid]( - q, k, v, sm_scale, b_start_loc, b_seq_len, + _context_flash_attention_kernel_2[grid]( + q, + k, + v, + sm_scale, + None, + b_start_loc, + b_seq_len, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - kv_group_num=kv_group_num, + 
kv_group_num, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, From 9dc997313576de01f8a94f15a810eec9abf0de00 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 15 Nov 2023 11:15:57 +0800 Subject: [PATCH 03/14] fix --- requirements/requirements-test.txt | 2 +- requirements/requirements.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index f54b13c7e43c..61b58055e666 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -12,7 +12,7 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package torchrec==0.2.0 contexttimer einops -triton==2.0.0.dev20221202 +triton==2.1.0 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece ninja diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f3d58ea72c60..095617d76355 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -16,4 +16,3 @@ ray sentencepiece google protobuf -pytest From 2a532abda61a5310974c4b92099d7da474e1248f Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 15 Nov 2023 14:37:43 +0800 Subject: [PATCH 04/14] reset --- colossalai/inference/README.md | 2 +- .../tensor_parallel/modeling/llama.py | 54 +++++++++---------- .../kernel/triton/token_attention_kernel.py | 28 +++++----- 3 files changed, 40 insertions(+), 44 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index cf5dbf245205..8d0e8eb614cf 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -66,7 +66,7 @@ flash-attention # install lightllm since we depend on lightllm triton kernels git clone https://github.com/ModelTC/lightllm cd 
lightllm -git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +git checkout 46c0a80914cda789a15407712bbdb9972336771d pip3 install -e . # also, install xformers from source: diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 62c2aad3c055..9146ad098401 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -12,11 +12,8 @@ from ._utils import copy_kv_to_mem_cache try: - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_llama2_context_attention_fwd, - ) from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_context_attention_fwd, + context_attention_fwd as lightllm_llama_context_attention_fwd, ) from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd @@ -56,32 +53,20 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): def llama_triton_context_attention( query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 ): - if num_key_value_groups == 1: - if HAS_LIGHTLLM_KERNEL is False: - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) - else: - lightllm_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) + # if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + 
infer_state.max_len_in_batch, + ) else: - assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" - lightllm_llama2_context_attention_fwd( + lightllm_llama_context_attention_fwd( query_states, key_states, value_states, @@ -91,6 +76,17 @@ def llama_triton_context_attention( # infer_state.cache_manager.past_key_values_length, infer_state.max_len_in_batch, ) + # else: + # llama_context_attn_fwd( + # query_states, + # key_states, + # value_states, + # attn_output, + # infer_state.start_loc, + # infer_state.seq_len, + # # infer_state.cache_manager.past_key_values_length, + # infer_state.max_len_in_batch, + # ) def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 8dc919bad125..3cf77e0fccf6 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -13,17 +13,17 @@ print("please install triton from https://github.com/openai/triton") try: - from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( - token_att_fwd as lightllm_llama2_token_att_fwd, - ) - from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( - token_att_fwd2 as lightllm_llama2_token_att_fwd2, - ) - from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( - token_softmax_fwd as lightllm_llama2_token_softmax_fwd, - ) + # from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( + # token_att_fwd as lightllm_llama2_token_att_fwd, + # ) + # from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( + # token_att_fwd2 as lightllm_llama2_token_att_fwd2, + # ) + # from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( + # token_softmax_fwd as lightllm_llama2_token_softmax_fwd, + # ) - from 
lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd @@ -72,7 +72,7 @@ def token_attention_fwd( lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - lightllm_llama_token_att_fw2( + lightllm_llama_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch ) prob = None @@ -203,7 +203,7 @@ def token_attn( calcu_shape1 = (batch_size, head_num, head_dim) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - lightllm_llama2_token_att_fwd( + lightllm_llama_token_att_fwd( q, k, att_m_tensor, @@ -215,12 +215,12 @@ def token_attn( if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - lightllm_llama2_token_softmax_fwd( + lightllm_llama_token_softmax_fwd( att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch ) att_m_tensor = None - lightllm_llama2_token_att_fwd2( + lightllm_llama_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), From 00eccf620a35f8758087c1c218d68e6a00652ab6 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 15 Nov 2023 14:52:46 +0800 Subject: [PATCH 05/14] add build script --- colossalai/inference/build.sh | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 colossalai/inference/build.sh diff --git a/colossalai/inference/build.sh 
b/colossalai/inference/build.sh new file mode 100644 index 000000000000..e9506cc107c0 --- /dev/null +++ b/colossalai/inference/build.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +# install triton +pip install triton +pip install transformers + +# install lightllm +mkdir 3rdParty +cd 3rdParty +git clone https://github.com/ModelTC/lightllm +cd lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +pip install -e . +cd ../../ + + + From ab9d3993225959251964c303b73867c8d706b11d Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 15 Nov 2023 14:55:46 +0800 Subject: [PATCH 06/14] add building process --- colossalai/inference/README.md | 17 +++++++++-------- colossalai/inference/build.sh | 9 ++++++++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index cf5dbf245205..ce9b6658b955 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -69,11 +69,11 @@ cd lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 pip3 install -e . -# also, install xformers from source: -pip install ninja -# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types -pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers +# install flash-attention +git clone -recursive https://github.com/Dao-AILab/flash-attention +cd flash-attention +pip install -e . ``` ### Docker @@ -95,10 +95,11 @@ cd lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 pip3 install -e . -# install xformers from source -pip install ninja -# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types -pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers +# install flash-attention +git clone -recursive https://github.com/Dao-AILab/flash-attention +cd flash-attention +pip install -e . + ``` ### Dive into fast-inference! 
diff --git a/colossalai/inference/build.sh b/colossalai/inference/build.sh index e9506cc107c0..6a73f6f0b985 100644 --- a/colossalai/inference/build.sh +++ b/colossalai/inference/build.sh @@ -4,14 +4,21 @@ pip install triton pip install transformers -# install lightllm +# install lightllm and flash-attention mkdir 3rdParty cd 3rdParty git clone https://github.com/ModelTC/lightllm cd lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 pip install -e . +cd .. + +git clone -recursive https://github.com/Dao-AILab/flash-attention +cd flash-attention +pip install -e . + cd ../../ + From 95bc86e2eefa1bb6215d0b28b04fb219f440a48c Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 15 Nov 2023 15:00:29 +0800 Subject: [PATCH 07/14] add llama2 example --- examples/inference/colossal_llama2_test.py | 81 ++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 examples/inference/colossal_llama2_test.py diff --git a/examples/inference/colossal_llama2_test.py b/examples/inference/colossal_llama2_test.py new file mode 100644 index 000000000000..299fb679982d --- /dev/null +++ b/examples/inference/colossal_llama2_test.py @@ -0,0 +1,81 @@ +import os +import warnings + +import pytest +import torch +import torch.distributed as dist +from packaging import version +from transformers import LlamaForCausalLM, LlamaTokenizer +from transformers import BloomForCausalLM, BloomTokenizerFast + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +# MODEL_PATH = "/home/lclcq/share/models--bigscience--bloom-560m/snapshots/4f42c91d806a19ae1a46af6c3fb5f4990d884cd6" +MODEL_PATH = "/home/lclcq/share/llama-7b" + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 1 +BATCH_SIZE = 4 +MAX_INPUT_LEN = 32
+MAX_OUTPUT_LEN = 128 + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') + + +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_llama_test(test_config): + + model_path = MODEL_PATH + if os.path.isdir(model_path) is False: + warnings.warn("Model path does not exist") + return + + # tokenizer = BloomTokenizerFast.from_pretrained(model_path) + tokenizer = LlamaTokenizer.from_pretrained(model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + # model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model = LlamaForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + + text = ["Introduce London.", "What is the genus of Poodle?"] + input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True) + + print(input_ids) + + shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, + inference_only=True) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + outputs = infer_engine.generate(input_ids, **generate_kwargs) + + assert outputs is not None + + if not dist.is_initialized() or dist.get_rank() == 0: + for o in outputs: + output_text = tokenizer.decode(o) + print(output_text) + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() From 511df424fb07ec31d50a41ff88ed13cee700bc89 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 15 
Nov 2023 16:18:27 +0800 Subject: [PATCH 08/14] add colossal-llama2 test --- .../tensor_parallel/modeling/llama.py | 11 ----- examples/inference/bench_llama.py | 1 - examples/inference/colossal_llama2_test.py | 41 +++++++++++-------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 9146ad098401..abb9e415ac50 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -76,17 +76,6 @@ def llama_triton_context_attention( # infer_state.cache_manager.past_key_values_length, infer_state.max_len_in_batch, ) - # else: - # llama_context_attn_fwd( - # query_states, - # key_states, - # value_states, - # attn_output, - # infer_state.start_loc, - # infer_state.seq_len, - # # infer_state.cache_manager.past_key_values_length, - # infer_state.max_len_in_batch, - # ) def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 4db32c71af30..c6eb3a5c68e4 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -28,7 +28,6 @@ def run_llama_test(args): tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) model = model.half() - model.config shard_config = ShardConfig( enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True} diff --git a/examples/inference/colossal_llama2_test.py b/examples/inference/colossal_llama2_test.py index 299fb679982d..379cfb882724 100644 --- a/examples/inference/colossal_llama2_test.py +++ b/examples/inference/colossal_llama2_test.py @@ -4,6 +4,7 @@ import pytest import torch import torch.distributed as dist +import argparse from packaging import version from transformers 
import LlamaForCausalLM, LlamaTokenizer from transformers import BloomForCausalLM, BloomTokenizerFast @@ -13,9 +14,10 @@ from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from transformers import AutoModelForCausalLM, AutoTokenizer -# MODEL_PATH = "/home/lclcq/share/models--bigscience--bloom-560m/snapshots/4f42c91d806a19ae1a46af6c3fb5f4990d884cd6" -MODEL_PATH = "/home/lclcq/share/llama-7b" +# # MODEL_PATH = "/home/lclcq/share/models--bigscience--bloom-560m/snapshots/4f42c91d806a19ae1a46af6c3fb5f4990d884cd6" +# MODEL_PATH = "/home/lclcq/share/llama-7b" os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 1 @@ -29,18 +31,13 @@ @parameterize('test_config', [{ 'tp_size': TPSIZE, }]) -def run_llama_test(test_config): +def run_llama_test(test_config, args): - model_path = MODEL_PATH - if os.path.isdir(model_path) is False: - warnings.warn("Model path does not exist") - return + model_path = args.path - # tokenizer = BloomTokenizerFast.from_pretrained(model_path) - tokenizer = LlamaTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer.pad_token_id = tokenizer.unk_token_id - # model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) - model = LlamaForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) model = model.half() text = ["Introduce London.", "What is the genus of Poodle?"] @@ -49,7 +46,7 @@ def run_llama_test(test_config): print(input_ids) shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) + extra_kwargs={"inference_only": True}) infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, 
MAX_OUTPUT_LEN) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) @@ -63,19 +60,29 @@ def run_llama_test(test_config): print(output_text) -def check_llama(rank, world_size, port): +def check_llama(rank, world_size, port, args): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_llama_test() + run_llama_test(args=args) @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_llama(): - spawn(check_llama, TPSIZE) +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) if __name__ == "__main__": - test_llama() + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--path", type=str, default = "hpcai-tech/Colossal-LLaMA-2-7b-base", help="Model path") + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) + args = parser.parse_args() + test_llama(args) From 939f3aa708dba663ef89f2b3c7755be63a0ab095 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 15 Nov 2023 16:29:17 +0800 Subject: [PATCH 09/14] clean --- colossalai/inference/README.md | 2 +- colossalai/kernel/triton/token_attention_kernel.py | 10 ---------- examples/inference/colossal_llama2_test.py | 4 ---- 3 files changed, 1 insertion(+), 15 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 25e6e44933dc..ce9b6658b955 100644 --- a/colossalai/inference/README.md 
+++ b/colossalai/inference/README.md @@ -66,7 +66,7 @@ flash-attention # install lightllm since we depend on lightllm triton kernels git clone https://github.com/ModelTC/lightllm cd lightllm -git checkout 46c0a80914cda789a15407712bbdb9972336771d +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 pip3 install -e . diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 3cf77e0fccf6..de2003748e65 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -13,16 +13,6 @@ print("please install triton from https://github.com/openai/triton") try: - # from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( - # token_att_fwd as lightllm_llama2_token_att_fwd, - # ) - # from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( - # token_att_fwd2 as lightllm_llama2_token_att_fwd2, - # ) - # from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( - # token_softmax_fwd as lightllm_llama2_token_softmax_fwd, - # ) - from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd diff --git a/examples/inference/colossal_llama2_test.py b/examples/inference/colossal_llama2_test.py index 379cfb882724..4651aa7c288f 100644 --- a/examples/inference/colossal_llama2_test.py +++ b/examples/inference/colossal_llama2_test.py @@ -6,8 +6,6 @@ import torch.distributed as dist import argparse from packaging import version -from transformers import LlamaForCausalLM, LlamaTokenizer -from transformers import BloomForCausalLM, BloomTokenizerFast import colossalai from 
colossalai.inference.tensor_parallel.engine import TPInferEngine @@ -16,8 +14,6 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from transformers import AutoModelForCausalLM, AutoTokenizer -# # MODEL_PATH = "/home/lclcq/share/models--bigscience--bloom-560m/snapshots/4f42c91d806a19ae1a46af6c3fb5f4990d884cd6" -# MODEL_PATH = "/home/lclcq/share/llama-7b" os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 1 From a3564b4935615182b40dac13f06d87c9c9aa59d2 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 16 Nov 2023 11:06:05 +0800 Subject: [PATCH 10/14] fall back test setting --- colossalai/inference/tensor_parallel/modeling/llama.py | 2 ++ tests/test_infer_ops/triton/test_llama_context_attention.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index abb9e415ac50..55854fb13e25 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -8,6 +8,7 @@ from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards +from lightllm.models.llama.triton_kernel.flash_decoding import token_decode_attention_flash_decoding from ._utils import copy_kv_to_mem_cache @@ -92,6 +93,7 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key # infer_state.cache_manager.past_key_values_length, infer_state.max_len_in_batch, ) + else: Llama2TokenAttentionForwards.token_attn( query_states, diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index 4250bed3dad5..556f419cd428 100644 --- 
a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -43,7 +43,7 @@ def test_llama_context_attention(): torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) assert torch.allclose( - torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2 + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 ), "outputs from triton and torch are not matched" From 1d0d5fe6ce2ff4b84f2ba7e01054dea7d4f4b1eb Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 16 Nov 2023 12:46:41 +0800 Subject: [PATCH 11/14] fix test file --- .../{colossal_llama2_test.py => colossal_llama2_demo.py} | 2 -- 1 file changed, 2 deletions(-) rename examples/inference/{colossal_llama2_test.py => colossal_llama2_demo.py} (95%) diff --git a/examples/inference/colossal_llama2_test.py b/examples/inference/colossal_llama2_demo.py similarity index 95% rename from examples/inference/colossal_llama2_test.py rename to examples/inference/colossal_llama2_demo.py index 4651aa7c288f..ad944cc1af37 100644 --- a/examples/inference/colossal_llama2_test.py +++ b/examples/inference/colossal_llama2_demo.py @@ -62,8 +62,6 @@ def check_llama(rank, world_size, port, args): run_llama_test(args=args) -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -@pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(args): From ed53588fc04da1da54ace7ddea61e376603b721b Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 16 Nov 2023 13:37:08 +0800 Subject: [PATCH 12/14] clean --- examples/inference/colossal_llama2_demo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/inference/colossal_llama2_demo.py b/examples/inference/colossal_llama2_demo.py index ad944cc1af37..72abab2a4eba 100644 --- a/examples/inference/colossal_llama2_demo.py +++ b/examples/inference/colossal_llama2_demo.py @@ -1,7 +1,6 @@ import os import 
warnings -import pytest import torch import torch.distributed as dist import argparse From 76f1268bd573394e8e30155899e5e6b610cc38c6 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 16 Nov 2023 13:39:03 +0800 Subject: [PATCH 13/14] clean --- tests/test_infer_ops/triton/test_llama_context_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index 556f419cd428..95fe50cf1d9c 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -41,7 +41,6 @@ def test_llama_context_attention(): llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose( torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 ), "outputs from triton and torch are not matched" From d7a93367705fd8c44cb29dd466fa290c828ecb69 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 16 Nov 2023 13:40:54 +0800 Subject: [PATCH 14/14] clean --- colossalai/inference/tensor_parallel/modeling/llama.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 55854fb13e25..448943b12c9e 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -8,10 +8,7 @@ from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards -from lightllm.models.llama.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - from ._utils import copy_kv_to_mem_cache - try: from 
lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama_context_attention_fwd,