From 3d641621c00ce7f358ea2d849bc67b02fb432bff Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 24 Aug 2023 10:52:58 +0800 Subject: [PATCH 1/6] add token forward --- .../kernel/triton/token_attention_kernel.py | 340 ++++++++++++++++++ .../test_kernels/triton/test_token_attn_1.py | 104 ++++++ .../test_kernels/triton/test_token_attn_2.py | 56 +++ .../triton/test_token_attn_fwd.py | 64 ++++ .../test_kernels/triton/test_token_softmax.py | 32 ++ 5 files changed, 596 insertions(+) create mode 100644 colossalai/kernel/triton/token_attention_kernel.py create mode 100644 tests/test_kernels/triton/test_token_attn_1.py create mode 100644 tests/test_kernels/triton/test_token_attn_2.py create mode 100644 tests/test_kernels/triton/test_token_attn_fwd.py create mode 100644 tests/test_kernels/triton/test_token_softmax.py diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py new file mode 100644 index 000000000000..c5872772e642 --- /dev/null +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -0,0 +1,340 @@ +import math + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, attn_out, + kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride, q_head_dim_stride, + k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride, attn_batch_stride, + HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + +@triton.jit +def _token_attn_1_alibi_kernel( + Q, + K, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, # kv_cache_start_loc 保存的是如果连续存储时候的累加输入和 + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, 
HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): # 用来判断当前 mask 是否需要计算 + alibi_m = tl.load(alibi + current_head) + q = tl.load(Q + off_q + start_mark) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + +@torch.no_grad() +def token_attn_fwd_1(q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, + logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride, + BLOCK_SIZE: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load(softmax_logics + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + 
other=-float('inf')).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store(softmax_prob_out + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len) + return + + +@torch.no_grad() +def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + +@triton.jit +def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, + kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride, + v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride, attn_out_head_stride, + attn_out_head_dim_stride, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0) + v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0) + v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = acc.to(tl.float16) + off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + +@torch.no_grad() +def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + 
prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + attn_out.stride(0), + attn_out.stride(1), + attn_out.stride(2), + BLOCK_DMODEL=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +@torch.no_grad() +def token_attention_fwd(q, + k, + v, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=None): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") + + token_attn_fwd_1(q.view(calcu_shape1), + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, + max_len_in_batch) + + prob = None + + return diff --git a/tests/test_kernels/triton/test_token_attn_1.py b/tests/test_kernels/triton/test_token_attn_1.py new file mode 100644 index 000000000000..dcfb28bd9ee0 --- /dev/null +++ b/tests/test_kernels/triton/test_token_attn_1.py @@ -0,0 +1,104 @@ +import math + +import torch + +from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 + + +def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): + xq = xq.view(bs, 1, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + keys = xk + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape( + num_head, -1) + # print("s ", scores.shape) + return scores + + +def torch_attn_1(xq, xk, seqlen, num_head, head_dim): + xq = xq.view(1, num_head, head_dim) + xk = xk.view(seqlen, num_head, head_dim) + logics = torch.sum(xq * xk, dim=-1, keepdim=False) + + logics = logics.transpose(0, 1) / math.sqrt(head_dim) + return logics + + +def test_attn_1(): + import time + + batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 + + dtype = torch.float16 + + q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") + + # print(attn_out) + + b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") + kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + + for i in range(batch_size): + kv_cache_start_loc[i] = i * seq_len + kv_cache_seq_len[i] = seq_len + b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") + # print(b_loc[i]) + + # Warm up + for _ in range(10): + token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + run_iter = 1000 + torch.cuda.synchronize() + t1 = time.time() + for _ in range(run_iter): + token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + torch.cuda.synchronize() + t2 = time.time() + print("Time cost {}".format((t2 - t1) / run_iter)) + + torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() + o = 
attn_out.squeeze() + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +# def test_alibi_attn_1(): +# import torch + +# batch_size, seq_len, head_num, head_dim = 2, 1025, 12, 128 + +# dtype = torch.float16 + +# q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) +# k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) +# attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") + +# # print(attn_out) + +# b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") +# kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") +# kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + +# for i in range(batch_size): +# kv_cache_start_loc[i] = i * seq_len +# kv_cache_seq_len[i] = seq_len +# b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") +# # print(b_loc[i]) + +# token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + +# torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() +# o = attn_out.squeeze() +# print("max ", torch.max(torch.abs(torch_out - o))) +# print("mean ", torch.mean(torch.abs(torch_out - o))) +# assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + +if __name__ == "__main__": + test_attn_1() + test_alibi_attn_1() diff --git a/tests/test_kernels/triton/test_token_attn_2.py b/tests/test_kernels/triton/test_token_attn_2.py new file mode 100644 index 000000000000..0bb67b5f718e --- /dev/null +++ b/tests/test_kernels/triton/test_token_attn_2.py @@ -0,0 +1,56 @@ +import math + +import torch + +from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 + + +def torch_attn(V, P, bs, seqlen, num_head, head_dim): + V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) + P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) + attn_out = torch.matmul(P, V) + + return attn_out + + +def test_token_attn_2(): + import time + + batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 + dtype = torch.float16 + + V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) + Prob = torch.empty( + (head_num, batch_size * seq_len), dtype=dtype, + device="cuda").normal_(mean=0.4, std=0.2).reshape(head_num, batch_size, + seq_len).softmax(-1).reshape(head_num, batch_size * seq_len) + attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") + + kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") + for i in range(batch_size): + kv_cache_start_loc[i] = i * seq_len + kv_cache_seq_len[i] = seq_len + kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") + + # Warm up + for _ in range(10): + token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + run_iter = 1000 + torch.cuda.synchronize() + t1 = time.time() + for _ in range(run_iter): + token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + torch.cuda.synchronize() + t2 = time.time() + print("Time cost {}".format((t2 - t1) / 
run_iter)) + torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze() + o = attn_out + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test_token_attn_2() diff --git a/tests/test_kernels/triton/test_token_attn_fwd.py b/tests/test_kernels/triton/test_token_attn_fwd.py new file mode 100644 index 000000000000..3949c60baac4 --- /dev/null +++ b/tests/test_kernels/triton/test_token_attn_fwd.py @@ -0,0 +1,64 @@ +import time + +import torch + +from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + + +def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): + xq = xq.view(bs, 1, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + + logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) + prob = torch.softmax(logics, dim=1) + prob = prob.view(bs, seqlen, num_head, 1) + + return torch.sum(prob * xv, dim=1, keepdim=False) + + +def test(): + + Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 + dtype = torch.float16 + q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") + + max_kv_cache_len = seq_len + kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") + kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") + + kv_cache_seq_len[:] = seq_len + kv_cache_start_loc[0] = 0 + kv_cache_start_loc[1] = seq_len + kv_cache_start_loc[2] = 2 * seq_len + kv_cache_start_loc[3] = 3 * seq_len + + for i in range(Z): + kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") + + token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) + torch.cuda.synchronize() + start = time.time() + token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) + torch.cuda.synchronize() + print("cost time:", (time.time() - start) * 1000) + + torch_att(q, k, v, Z, seq_len, head_num, head_dim) + torch.cuda.synchronize() + start = time.time() + torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) + torch.cuda.synchronize() + print("cost time:", (time.time() - start) * 1000) + + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test() diff --git a/tests/test_kernels/triton/test_token_softmax.py b/tests/test_kernels/triton/test_token_softmax.py new file mode 100644 index 000000000000..cc1d0b213afa --- /dev/null +++ b/tests/test_kernels/triton/test_token_softmax.py @@ -0,0 +1,32 @@ +from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd + + +def test_softmax(): + + import torch + + batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128 + + dtype = torch.float16 + + Logics = 
torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) + ProbOut = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + + kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") + + for i in range(batch_size): + kv_cache_start_loc[i] = i * seq_len + kv_cache_seq_len[i] = seq_len + + token_attn_softmax_fwd(Logics, kv_cache_start_loc, kv_cache_seq_len, ProbOut, seq_len) + + torch_out = Logics.reshape(head_num * batch_size, -1).softmax(-1).reshape(head_num, batch_size * seq_len) + o = ProbOut + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + + +if __name__ == "__main__": + test_softmax() From 6a434be6fac1de0abfdd802f21d28a5805f6e44f Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 24 Aug 2023 11:05:37 +0800 Subject: [PATCH 2/6] fix tests --- colossalai/kernel/triton/token_attention_kernel.py | 8 ++++---- tests/test_kernels/triton/test_token_attn_1.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index c5872772e642..e39f4007967b 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -229,12 +229,12 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride, v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride, attn_out_head_stride, - attn_out_head_dim_stride, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): + attn_out_head_dim_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): current_batch = tl.program_id(0) current_head = tl.program_id(1) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = tl.arange(0, HEAD_DIM) current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) current_batch_start_index = max_kv_cache_len - current_batch_seq_len current_batch_end_index = current_batch_seq_len @@ -244,7 +244,7 @@ def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) for start_n in range(0, current_batch_seq_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride, @@ -294,7 +294,7 @@ def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cac attn_out.stride(0), attn_out.stride(1), attn_out.stride(2), - BLOCK_DMODEL=dim, + HEAD_DIM=dim, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, diff --git a/tests/test_kernels/triton/test_token_attn_1.py b/tests/test_kernels/triton/test_token_attn_1.py index dcfb28bd9ee0..138e0394af77 100644 --- a/tests/test_kernels/triton/test_token_attn_1.py +++ b/tests/test_kernels/triton/test_token_attn_1.py @@ -101,4 +101,4 @@ def test_attn_1(): if __name__ == "__main__": test_attn_1() - 
test_alibi_attn_1() + # test_alibi_attn_1() From a6d4c0ebc63823e6e2c1bd5806eb3638fb3e0fbf Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 24 Aug 2023 11:18:39 +0800 Subject: [PATCH 3/6] fix comments --- .../kernel/triton/token_attention_kernel.py | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index e39f4007967b..3d48f44a49e0 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -44,28 +44,10 @@ def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_ca @triton.jit -def _token_attn_1_alibi_kernel( - Q, - K, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, # kv_cache_start_loc 保存的是如果连续存储时候的累加输入和 - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr): +def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, + max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, + q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, + attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): current_batch = tl.program_id(0) current_head = tl.program_id(1) start_n = tl.program_id(2) @@ -84,7 +66,7 @@ def _token_attn_1_alibi_kernel( block_stard_index = start_n * BLOCK_N block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - for start_mark in range(0, block_mask, 1): # 用来判断当前 mask 是否需要计算 + for start_mark in range(0, block_mask, 1): alibi_m = tl.load(alibi + current_head) q = tl.load(Q + off_q + start_mark) offs_n_new = current_batch_start_index + offs_n From 17ff93eb1161f2c34cd04db85a17022fcd155b7e Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 24 Aug 2023 11:31:57 +0800 Subject: [PATCH 4/6] add try import triton --- .../kernel/triton/token_attention_kernel.py | 615 +++++++++--------- 1 file changed, 312 insertions(+), 303 deletions(-) diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 3d48f44a49e0..ee4b87332d7a 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -1,322 +1,331 @@ import math import torch -import triton -import triton.language as tl - - -@triton.jit -def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, attn_out, - kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride, q_head_dim_stride, - k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride, attn_batch_stride, - HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = 
start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - -@triton.jit -def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, - max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, - q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, - attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - alibi_m = tl.load(alibi + current_head) - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - -@torch.no_grad() -def token_attn_fwd_1(q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None): - BLOCK = 32 - # shape constraints - q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] - assert q_head_dim == k_head_dim - assert k_head_dim in {16, 32, 64, 128} - sm_scale = 1.0 / (k_head_dim**0.5) - - batch, head_num = kv_cache_loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) - - num_warps = 4 if k_head_dim <= 64 else 8 - num_warps = 2 - - if alibi is not None: - _token_attn_1_alibi_kernel[grid]( - q, - k, - sm_scale, - alibi, - kv_cache_loc, + +try: + import triton + import triton.language as tl + HAS_TRITON = 
True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + + @triton.jit + def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, + attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride, + q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride, + attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @triton.jit + def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, + max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, + q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride, + k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + alibi_m = tl.load(alibi + current_head) + q = tl.load(Q + off_q + start_mark) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + 
att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @torch.no_grad() + def token_attn_fwd_1(q, + k, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + alibi=None): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, + logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride, + BLOCK_SIZE: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load(softmax_logics + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float('inf')).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store(softmax_prob_out + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len) + return + + @torch.no_grad() + def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, kv_cache_start_loc, kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - 
HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), num_warps=num_warps, - num_stages=1, + BLOCK_SIZE=BLOCK_SIZE, ) - else: - _token_attn_1_kernel[grid]( - q, - k, - sm_scale, + return + + @triton.jit + def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, + kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride, + v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride, + attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0) + v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0) + v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = acc.to(tl.float16) + off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + @torch.no_grad() + def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, - attn_out, kv_cache_loc.stride(0), kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), attn_out.stride(0), attn_out.stride(1), - HEAD_DIM=k_head_dim, + attn_out.stride(2), + HEAD_DIM=dim, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, ) - return - - -@triton.jit -def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, - logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride, - BLOCK_SIZE: tl.constexpr): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = 
tl.load(kv_cache_start_loc + current_batch) - - row = tl.load(softmax_logics + current_head * logics_head_dim_stride + - (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float('inf')).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store(softmax_prob_out + current_head * prob_head_dim_stride + - (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len) - return - - -@torch.no_grad() -def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): - BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) - batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _token_attn_softmax_fwd[(batch, head_num)]( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - softmax_logics.stride(0), - softmax_logics.stride(1), - softmax_prob_out.stride(0), - softmax_prob_out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - -@triton.jit -def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, - kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride, - v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride, attn_out_head_stride, - attn_out_head_dim_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = current_batch_seq_len - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride - p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride - v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - for start_n in range(0, current_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0) - v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0) - v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride - out_ptrs = attn_out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = kv_cache_loc.shape[0], v.shape[1] - grid = (batch, head) - 
num_warps = 4 - dim = v.shape[-1] - - _token_attn_2_kernel[grid]( - prob, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - attn_out.stride(0), - attn_out.stride(1), - attn_out.stride(2), - HEAD_DIM=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -@torch.no_grad() -def token_attention_fwd(q, - k, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=None): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - token_attn_fwd_1(q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi) - - prob = torch.empty_like(att_m_tensor) - - token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) - att_m_tensor = None - token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, - max_len_in_batch) - - prob = None - - return + return + + @torch.no_grad() + def token_attention_fwd(q, + k, + v, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=None): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") + + token_attn_fwd_1(q.view(calcu_shape1), + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, + max_len_in_batch) + + prob = None + + return From ad62b63d61b91c9de902a0e9e0f84df96c263635 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 24 Aug 2023 11:43:15 +0800 Subject: [PATCH 5/6] add adapted license --- colossalai/kernel/triton/token_attention_kernel.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index ee4b87332d7a..c6b25f4abcec 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -1,3 +1,5 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm + import math import torch From 294c431e90e00ad047c962f7f12426248f265fa4 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 24 Aug 2023 14:17:58 +0800 Subject: [PATCH 6/6] add tests check --- .../test_kernels/triton/test_token_attn_1.py | 20 ++++++++++++++----- .../test_kernels/triton/test_token_attn_2.py | 16 ++++++++++++++- .../triton/test_token_attn_fwd.py | 16 ++++++++++++++- .../test_kernels/triton/test_token_softmax.py | 18 ++++++++++++++++- 4 files changed, 62 insertions(+), 8 deletions(-) diff --git a/tests/test_kernels/triton/test_token_attn_1.py b/tests/test_kernels/triton/test_token_attn_1.py index 138e0394af77..ba236de82498 100644 --- a/tests/test_kernels/triton/test_token_attn_1.py +++ 
b/tests/test_kernels/triton/test_token_attn_1.py @@ -1,8 +1,20 @@ import math +import pytest import torch +from packaging import version -from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 +try: + import triton + import triton.language as tl + + from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): @@ -13,7 +25,6 @@ def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): keys = keys.transpose(1, 2) scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape( num_head, -1) - # print("s ", scores.shape) return scores @@ -26,6 +37,8 @@ def torch_attn_1(xq, xk, seqlen, num_head, head_dim): return logics +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_attn_1(): import time @@ -37,8 +50,6 @@ def test_attn_1(): k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - # print(attn_out) - b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") @@ -47,7 +58,6 @@ def test_attn_1(): kv_cache_start_loc[i] = i * seq_len kv_cache_seq_len[i] = seq_len b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - # print(b_loc[i]) # Warm up for _ in range(10): diff --git a/tests/test_kernels/triton/test_token_attn_2.py b/tests/test_kernels/triton/test_token_attn_2.py index 0bb67b5f718e..36b517c4aa3b 100644 --- a/tests/test_kernels/triton/test_token_attn_2.py +++ b/tests/test_kernels/triton/test_token_attn_2.py @@ -1,8 +1,20 @@ import math +import pytest import torch +from packaging import version -from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 +try: + import triton + import triton.language as tl + + from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') def torch_attn(V, P, bs, seqlen, num_head, head_dim): @@ -13,6 +25,8 @@ def torch_attn(V, P, bs, seqlen, num_head, head_dim): return attn_out +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_token_attn_2(): import time diff --git a/tests/test_kernels/triton/test_token_attn_fwd.py b/tests/test_kernels/triton/test_token_attn_fwd.py index 3949c60baac4..e765ed4a3415 100644 --- a/tests/test_kernels/triton/test_token_attn_fwd.py +++ b/tests/test_kernels/triton/test_token_attn_fwd.py @@ -1,8 +1,20 @@ import time +import pytest import torch +from packaging import version -from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd +try: + import triton + import triton.language as tl + + from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + HAS_TRITON = True +except 
ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): @@ -17,6 +29,8 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): return torch.sum(prob * xv, dim=1, keepdim=False) +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test(): Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 diff --git a/tests/test_kernels/triton/test_token_softmax.py b/tests/test_kernels/triton/test_token_softmax.py index cc1d0b213afa..08ffe1ca8323 100644 --- a/tests/test_kernels/triton/test_token_softmax.py +++ b/tests/test_kernels/triton/test_token_softmax.py @@ -1,6 +1,22 @@ -from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd +import pytest +import torch +from packaging import version +try: + import triton + import triton.language as tl + from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_softmax(): import torch
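

Taken together, the three stages in these patches implement decode-time attention (one new query token per sequence) over a token-indexed KV cache: token_attn_fwd_1 writes the scaled q·k score for every cached position into a (head_num, total_token_num) buffer, token_attn_softmax_fwd normalizes each sequence's slice of that buffer per head, and token_attn_fwd_2 gathers the probability-weighted values through kv_cache_loc. The sketch below is a minimal usage example distilled from test_token_attn_fwd.py above; the sizes, the random inputs, and the contiguous slot layout are illustrative choices taken from the tests, not requirements of the kernels.

import torch

from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

batch, head_num, seq_len, head_dim = 4, 16, 1024, 128
dtype = torch.float16

# one new query token per sequence; k/v hold every cached token of all sequences
q = torch.randn((batch, head_num, head_dim), dtype=dtype, device="cuda")
k = torch.randn((batch * seq_len, head_num, head_dim), dtype=dtype, device="cuda")
v = torch.randn((batch * seq_len, head_num, head_dim), dtype=dtype, device="cuda")
out = torch.empty((batch, head_num, head_dim), dtype=dtype, device="cuda")

# per-sequence cache metadata: start offset, current length, and the physical
# slot of every cached token (laid out contiguously here, as in the tests)
kv_cache_start_loc = torch.arange(0, batch * seq_len, seq_len, dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")
kv_cache_loc = kv_cache_start_loc[:, None] + torch.arange(seq_len, dtype=torch.int32, device="cuda")[None, :]

token_attention_fwd(q, k, v, out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

When alibi is supplied, the kernels read one slope per head and subtract slope * (seq_len - 1 - position) from each score before the softmax, so a zero tensor (as used in test_token_attn_fwd.py) reduces to plain scaled dot-product attention.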