From f98a882a67aeaecd8dd6ee40feff789b80834211 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 16 Jul 2023 14:53:17 +0800 Subject: [PATCH 01/22] added softmax kernel --- colossalai/kernel/triton/softmax_kernel.py | 39 ++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 colossalai/kernel/triton/softmax_kernel.py diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py new file mode 100644 index 000000000000..2c0f9d7b84d2 --- /dev/null +++ b/colossalai/kernel/triton/softmax_kernel.py @@ -0,0 +1,39 @@ +import triton +import triton.language as tl +''' +softmax kernel is modified based on +https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py +''' + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): + r""" the kernel function for implementing softmax operator + Args: + output_ptr: the output after finishing softmax operation, (N, hidden_dim) + input_ptr: the tensor of input, shape should be (N, hidden_dim) + n_cols(tl.constexpr): the number of cols of input + BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim + """ + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * row_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + + if mask_ptr is not None: + # load mask into SRAM + mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + + # update + row_minus_max = row_minus_max + mask + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * row_stride + output_ptrs = output_row_start_ptr + col_offsets + # Write back output to DRAM + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) \ No newline at end of file From f15b661f9af36b0069c3309a4806743437af5725 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 16 Jul 2023 14:58:39 +0800 Subject: [PATCH 02/22] added qkv_kernel --- colossalai/kernel/triton/qkv_matmul_kernel.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 colossalai/kernel/triton/qkv_matmul_kernel.py diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py new file mode 100644 index 000000000000..f913928c3b0d --- /dev/null +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -0,0 +1,105 @@ +import torch +import triton +import triton.language as tl + +from inference.ops.triton.k_activations import leaky_relu, relu + + +''' +this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html +''' +@triton.jit +def qkv_gemm_4d_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + # Meta-parameters + BLOCK_SIZE_M : tl.constexpr = 64, + BLOCK_SIZE_N : tl.constexpr = 32, + BLOCK_SIZE_K : tl.constexpr = 32, + GROUP_SIZE_M : tl.constexpr = 8, +): + r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention 
layer, + where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) + Args: + a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) + b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) + c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) + stride_ab(tl.constexpr): stride for bs-dimention for tensor array A + stride_ah(tl.constexpr): stride for h-dimention for tensor array A + stride_am(tl.constexpr): stride for m-dimention for tensor array A + stride_ak(tl.constexpr): stride for k-dimention for tensor array A + stride_bb(tl.constexpr): stride for bs-dimention for tensor array B + stride_bh(tl.constexpr): stride for h-dimention for tensor array B + stride_bk(tl.constexpr): stride for k-dimention for tensor array B + stride_bn(tl.constexpr): stride for n-dimention for tensor array B + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_ch(tl.constexpr): stride for h-dimention for tensor array output + stride_cm(tl.constexpr): stride for m-dimention for tensor array output + stride_cn(tl.constexpr): stride for n-dimention for tensor array output + BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a + BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b + BLOCK_SIZE_K : tiling size for K-dimension of a and b + GROUP_SIZE_M : group size for reducing cache miss, more details: + """ + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + batch = tl.program_id(axis = 0) + head = tl.program_id(axis = 1) + pid = tl.program_id(axis = 2) + + # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) + b = tl.load(b_ptrs, mask=b_mask, other=0.) 
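+        # masked (out-of-range) elements were loaded as 0., so they contribute nothing
+        # to the (BLOCK_SIZE_M, BLOCK_SIZE_N) partial product accumulated in float32 below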
+ accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + accumulator = accumulator.to(c_ptr.dtype.element_ty) + if scale > 0: + accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) + + + offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :]) + accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) + tl.store(c_ptrs, accumulator, mask=accumulator_mask) From 04bd4287392eb16e844402fb858f001eefe2ce5b Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 16 Jul 2023 15:07:50 +0800 Subject: [PATCH 03/22] added ops --- colossalai/kernel/triton/ops.py | 161 ++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 colossalai/kernel/triton/ops.py diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py new file mode 100644 index 000000000000..3d0eac1c71fc --- /dev/null +++ b/colossalai/kernel/triton/ops.py @@ -0,0 +1,161 @@ +import torch +from torch import nn + +import triton +import triton.language as tl + +from .qkv_matmul_kernel import qkv_gemm_4d_kernel +from .softmax_kernel import softmax_kernel + +def self_attention_forward_with_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float): + + # TODO: call flash attention kernel implemetation (@cuiqing.li (tiandiao123) works on) + pass + + +def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): + r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + Args: + q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) + scale: the float scale value which is used to multiply with Q*K^T before doing softmax + + Return: + output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) + """ + assert len(q.shape) == 4, "the shape of q val must be 4" + batches, M, H, K = q.shape + assert q.shape == k.shape, "the shape of q and the shape of k must be equal" + assert q.shape == v.shape, "the shape of q and the shape of v must be equal" + assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" + + N = k.shape[1] + + # head_size * num_of_head + d_model = q.shape[-1] * q.shape[-2] + + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + q, k, score_output, + M, N, K, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(3), k.stride(1), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + softmax_output = torch.empty( + 
score_output.shape, device=score_output.device, dtype=score_output.dtype) + score_output_shape = score_output.shape + + score_output = score_output.view(-1, score_output.shape[-1]) + n_rows, n_cols = score_output.shape + + if n_rows <= 350000: + + block_size = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + + softmax_kernel[(n_rows, )]( + softmax_output, + score_output, + score_output.stride(0), + n_cols, + mask_ptr = input_mask, + num_warps=num_warps, + BLOCK_SIZE=block_size, + ) + + else: + #TODO: change softmax kernel functions to make it suitable for large size dimension + softmax_output = torch.nn.functional.softmax(score_output, dim=-1) + softmax_output = softmax_output.view(*score_output_shape) + + batches, H, M, K = softmax_output.shape + N = v.shape[-1] + + output = torch.empty( + (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + softmax_output, v, output, + M, N, K, + softmax_output.stride(0), + softmax_output.stride(1), + softmax_output.stride(2), + softmax_output.stride(3), + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, + scale=-1, + ) + return output.view(batches, -1, d_model) + + +def self_attention_compute_using_triton(qkv, + input_mask, + layer_past, + alibi, + scale, + head_size, + triangular=False, + use_flash=False): + + assert qkv.is_contiguous() + assert alibi is None, "current triton self-attention does not support alibi" + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + if use_flash: + data_output_triton = self_attention_forward_with_fusion(q, k, v, input_mask, scale) + else: + data_output_triton = self_attention_forward_without_fusion( + q, k, v, input_mask, scale) + + return data_output_triton From 754d6e9f84125af816717c6b7370b5cf49930409 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 16 Jul 2023 15:12:50 +0800 Subject: [PATCH 04/22] adding tests --- tests/test_kernels/test_self_attention.py | 2 ++ tests/test_kernels/test_softmax.py | 11 +++++++++++ 2 files changed, 13 insertions(+) create mode 100644 tests/test_kernels/test_self_attention.py create mode 100644 tests/test_kernels/test_softmax.py diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py new file mode 100644 index 000000000000..61cb712d8f94 --- /dev/null +++ b/tests/test_kernels/test_self_attention.py @@ -0,0 +1,2 @@ +import torch +from torch import nn \ No newline at end of file diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py new file mode 100644 index 000000000000..6dfd9415d661 --- /dev/null +++ b/tests/test_kernels/test_softmax.py @@ -0,0 +1,11 @@ +import pytest +import torch +from torch import nn + +def test_softmax_op(): + device = "cuda" + data = torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32) + + +if __name__ == "__main__": + 
test_softmax_op() \ No newline at end of file From 978de66d62058f6708b2855383123b29db0351b2 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 16 Jul 2023 16:10:12 +0800 Subject: [PATCH 05/22] upload tets --- tests/test_kernels/test_softmax.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index 6dfd9415d661..86ce8bd0a485 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -2,10 +2,21 @@ import torch from torch import nn +from colossalai.kernel.triton.ops import softmax + def test_softmax_op(): device = "cuda" - data = torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32) + data_samples = [torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), + torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), + ] + for data in data_samples: + data = torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32) + module = nn.Softmax(dim = -1) + data_torch_out = module(data) + + data_triton_out = softmax(data) + print(torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-4, atol=1e-4)) + - if __name__ == "__main__": test_softmax_op() \ No newline at end of file From 8390da3c63c0102433f16512492b629c425300de Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 16 Jul 2023 16:12:20 +0800 Subject: [PATCH 06/22] fix tests --- tests/test_kernels/test_softmax.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index 86ce8bd0a485..a47c1bb6228e 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -6,16 +6,17 @@ def test_softmax_op(): device = "cuda" - data_samples = [torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), - torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), + data_samples = [ + torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), + torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), + torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16) ] + for data in data_samples: - data = torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32) module = nn.Softmax(dim = -1) data_torch_out = module(data) - data_triton_out = softmax(data) - print(torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-4, atol=1e-4)) + print(torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)) if __name__ == "__main__": From dbea4f98f1d7b1955a0af94bf96fe0feaf33b748 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 16 Jul 2023 17:56:56 +0800 Subject: [PATCH 07/22] debugging --- tests/test_kernels/test_self_attention.py | 73 ++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index 61cb712d8f94..a1a63cd9906c 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -1,2 +1,73 @@ import torch -from torch import nn \ No newline at end of file +from torch import nn +import torch.nn.functional as F + +from colossalai.kernel.triton.ops import self_attention_compute_using_triton + +def self_attention_compute_using_torch(qkv, + input_mask, + scale, + head_size + ): + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = 
qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + q = torch.transpose(q, 1, 2) + k = torch.transpose(k, 1, 2) + v = torch.transpose(v, 1, 2) + + k = torch.transpose(k, -1, -2) + + score_output = torch.einsum('bnij,bnjk->bnik', q, k) + score_output *= scale + + score_output = F.softmax(score_output, dim = -1) + res = torch.einsum('bnij,bnjk->bnik', score_output, v) + res = torch.transpose(res, 1, 3) + res = res.contiguous() + + + return res.view(batches, -1, d_model) + + + + + +def test_self_atttention_test(): + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float32) + data_output_torch = self_attention_compute_using_torch( + qkv.clone(), + input_mask = None, + scale = 1.2, + head_size = 32 + ) + + data_output_triton = self_attention_compute_using_triton( + qkv.clone(), + alibi=None, + head_size=32, + scale=1.2, + input_mask=None, + layer_past=None, + use_flash=False, + triangular=True) + print(data_output_torch.shape) + print(data_output_triton.shape) + print(data_output_torch) + print(data_output_triton) + exit(0) + print(torch.allclose(data_output_torch.cpu(), data_output_triton.cpu(), rtol=1e-3, atol=1e-3)) + + + + +if __name__ == "__main__": + test_self_atttention_test() \ No newline at end of file From 9dfe619c1d7af74756a21700b51598520ca3618f Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 12:14:41 +0800 Subject: [PATCH 08/22] debugging tests --- tests/test_kernels/test_self_attention.py | 62 ++++++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index a1a63cd9906c..f24eb7e76963 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -1,8 +1,10 @@ +import triton import torch from torch import nn import torch.nn.functional as F from colossalai.kernel.triton.ops import self_attention_compute_using_triton +from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel def self_attention_compute_using_torch(qkv, input_mask, @@ -37,8 +39,63 @@ def self_attention_compute_using_torch(qkv, return res.view(batches, -1, d_model) +def test_qkv_matmul(): + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float32) + scale = 1.2 + head_size = 32 + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + q_copy = q.clone() + k_copy = k.clone() + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + k = torch.transpose(k, 2, 3).contiguous() + + print(q.shape) + print(k.shape) + + torch_ouput = torch.einsum('bnij,bnjk->bnik', q, k) + torch_ouput *= 1.2 + + q, k = q_copy, k_copy + batches, M, H, K = q.shape + N = k.shape[1] + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + K = q.shape[3] + qkv_gemm_4d_kernel[grid]( + q, k, score_output, + M, N, K, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(3), k.stride(1), + score_output.stride(0), score_output.stride(1), 
score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-2) + assert check is True, "the outputs of triton and torch are not matched" + - def test_self_atttention_test(): @@ -70,4 +127,5 @@ def test_self_atttention_test(): if __name__ == "__main__": - test_self_atttention_test() \ No newline at end of file + test_qkv_matmul() + # test_self_atttention_test() \ No newline at end of file From a24dbd669b693a59ee664144cae5acf5f14a3491 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 13:26:07 +0800 Subject: [PATCH 09/22] debugging --- tests/test_kernels/test_self_attention.py | 47 ++++++++++++----------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index f24eb7e76963..7ab4cd4fdf5d 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -22,25 +22,25 @@ def self_attention_compute_using_torch(qkv, k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) - q = torch.transpose(q, 1, 2) - k = torch.transpose(k, 1, 2) - v = torch.transpose(v, 1, 2) + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + v = torch.transpose(v, 1, 2).contiguous() - k = torch.transpose(k, -1, -2) + k = torch.transpose(k, -1, -2).contiguous() score_output = torch.einsum('bnij,bnjk->bnik', q, k) score_output *= scale - score_output = F.softmax(score_output, dim = -1) - res = torch.einsum('bnij,bnjk->bnik', score_output, v) + softmax_output = F.softmax(score_output, dim = -1) + res = torch.einsum('bnij,bnjk->bnik', softmax_output, v) res = torch.transpose(res, 1, 3) res = res.contiguous() - return res.view(batches, -1, d_model) + return res.view(batches, -1, d_model), score_output, softmax_output def test_qkv_matmul(): - qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float32) + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) scale = 1.2 head_size = 32 batches = qkv.shape[0] @@ -92,22 +92,21 @@ def test_qkv_matmul(): GROUP_SIZE_M=8, ) - check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-2) + check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) + print(check) assert check is True, "the outputs of triton and torch are not matched" - - def test_self_atttention_test(): - qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float32) - data_output_torch = self_attention_compute_using_torch( + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) + data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( qkv.clone(), input_mask = None, scale = 1.2, head_size = 32 ) - data_output_triton = self_attention_compute_using_triton( + data_output_triton, score_output_triton, softmax_output_triton = self_attention_compute_using_triton( qkv.clone(), alibi=None, head_size=32, @@ -116,16 +115,20 @@ def test_self_atttention_test(): layer_past=None, use_flash=False, triangular=True) - print(data_output_torch.shape) - print(data_output_triton.shape) - print(data_output_torch) - print(data_output_triton) - exit(0) - print(torch.allclose(data_output_torch.cpu(), data_output_triton.cpu(), 
rtol=1e-3, atol=1e-3)) + # print(data_output_torch.shape) + # print(data_output_triton.shape) + # print(data_output_torch) + # print(data_output_triton) + # print(torch.allclose(data_output_torch.cpu(), data_output_triton.cpu(), rtol=1e-2, atol=1e-2)) + + print(score_output_torch.shape) + print("**********") + print(score_output_triton.shape) + print(torch.allclose(data_output_torch.cpu(), data_output_triton.cpu(), rtol=1e-2, atol=1e-2)) if __name__ == "__main__": - test_qkv_matmul() - # test_self_atttention_test() \ No newline at end of file + # test_qkv_matmul() + test_self_atttention_test() \ No newline at end of file From 90678d4ac80983b636f975f0b8d1ba821dca3d94 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 14:03:21 +0800 Subject: [PATCH 10/22] added --- tests/test_kernels/test_self_attention.py | 83 ++++++++++------------- 1 file changed, 37 insertions(+), 46 deletions(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index 7ab4cd4fdf5d..5787f7d5edd0 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -6,38 +6,6 @@ from colossalai.kernel.triton.ops import self_attention_compute_using_triton from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel -def self_attention_compute_using_torch(qkv, - input_mask, - scale, - head_size - ): - batches = qkv.shape[0] - d_model = qkv.shape[-1] // 3 - num_of_heads = d_model // head_size - - q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] - v = qkv[:, :, d_model * 2:] - q = q.view(batches, -1, num_of_heads, head_size) - k = k.view(batches, -1, num_of_heads, head_size) - v = v.view(batches, -1, num_of_heads, head_size) - - q = torch.transpose(q, 1, 2).contiguous() - k = torch.transpose(k, 1, 2).contiguous() - v = torch.transpose(v, 1, 2).contiguous() - - k = torch.transpose(k, -1, -2).contiguous() - - score_output = torch.einsum('bnij,bnjk->bnik', q, k) - score_output *= scale - - softmax_output = F.softmax(score_output, dim = -1) - res = torch.einsum('bnij,bnjk->bnik', softmax_output, v) - res = torch.transpose(res, 1, 3) - res = res.contiguous() - - - return res.view(batches, -1, d_model), score_output, softmax_output def test_qkv_matmul(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) @@ -58,9 +26,6 @@ def test_qkv_matmul(): k = torch.transpose(k, 1, 2).contiguous() k = torch.transpose(k, 2, 3).contiguous() - print(q.shape) - print(k.shape) - torch_ouput = torch.einsum('bnij,bnjk->bnik', q, k) torch_ouput *= 1.2 @@ -93,10 +58,43 @@ def test_qkv_matmul(): ) check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) - print(check) assert check is True, "the outputs of triton and torch are not matched" +def self_attention_compute_using_torch(qkv, + input_mask, + scale, + head_size + ): + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + v = torch.transpose(v, 1, 2).contiguous() + + k = torch.transpose(k, -1, -2).contiguous() + + score_output = torch.einsum('bnij,bnjk->bnik', q, k) + score_output *= scale + + softmax_output = F.softmax(score_output, dim = -1) + res = 
torch.einsum('bnij,bnjk->bnik', softmax_output, v) + res = torch.transpose(res, 1, 2) + res = res.contiguous() + + + return res.view(batches, -1, d_model), score_output, softmax_output + + def test_self_atttention_test(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( @@ -115,20 +113,13 @@ def test_self_atttention_test(): layer_past=None, use_flash=False, triangular=True) - # print(data_output_torch.shape) - # print(data_output_triton.shape) - # print(data_output_torch) - # print(data_output_triton) - # print(torch.allclose(data_output_torch.cpu(), data_output_triton.cpu(), rtol=1e-2, atol=1e-2)) - print(score_output_torch.shape) - print("**********") - print(score_output_triton.shape) - print(torch.allclose(data_output_torch.cpu(), data_output_triton.cpu(), rtol=1e-2, atol=1e-2)) + check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) + assert check is True, "the triton ouput is not matched with torch output" if __name__ == "__main__": - # test_qkv_matmul() + test_qkv_matmul() test_self_atttention_test() \ No newline at end of file From 11af697b796ba3880acf3078130af0ac2a5ac999 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 14:05:56 +0800 Subject: [PATCH 11/22] fixed errors --- tests/test_kernels/test_self_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index 5787f7d5edd0..5151c0c14a6c 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -104,7 +104,7 @@ def test_self_atttention_test(): head_size = 32 ) - data_output_triton, score_output_triton, softmax_output_triton = self_attention_compute_using_triton( + data_output_triton = self_attention_compute_using_triton( qkv.clone(), alibi=None, head_size=32, @@ -114,7 +114,7 @@ def test_self_atttention_test(): use_flash=False, triangular=True) - check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) + check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-3, atol=1e-2) assert check is True, "the triton ouput is not matched with torch output" From 31f7de5fa98bb3477d62c1f21249cff30801f7b5 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 14:19:26 +0800 Subject: [PATCH 12/22] added softmax kernel --- colossalai/kernel/triton/ops.py | 53 ++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py index 3d0eac1c71fc..60f34175db73 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/ops.py @@ -78,7 +78,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t else: num_warps = 4 - softmax_kernel[(n_rows, )]( softmax_output, score_output, @@ -152,6 +151,7 @@ def self_attention_compute_using_triton(qkv, q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) + if use_flash: data_output_triton = self_attention_forward_with_fusion(q, k, v, input_mask, scale) else: @@ -159,3 +159,54 @@ def self_attention_compute_using_triton(qkv, q, k, v, input_mask, scale) return data_output_triton + + +def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + if mask is 
not None: + assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" + assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" + + hidden_dim = input.shape[-1] + output = torch.empty_like(input) + input = input.view(-1, hidden_dim) + if mask is not None: + mask = mask.view(-1, hidden_dim) + assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" + + num_rows, num_cols = input.shape + block_size = max(triton.next_power_of_2(num_cols), 2) + num_warps = 16 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + if num_rows <= 350000: + grid = (num_rows,) + softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + else: + grid = lambda meta: () + + grid = lambda meta: ( + triton.cdiv(num_rows, meta["BLOCK_M"]), + ) + + BLOCK_M = 32 + if block_size >= 4096: + BLOCK_M = 4 + elif block_size >= 2048: + BLOCK_M = 8 + + softmax_kernel_2[grid](output_ptr = output, + input_ptr = input, + row_stride = input.stride(0), + n_rows = num_rows, + n_cols = num_cols, + mask_ptr = mask, + # currently manually setting up size + BLOCK_M = 32, + BLOCK_SIZE = block_size) + + return output \ No newline at end of file From c39a2b1e624d129587ba7301c12bb95c55b65e52 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 14:22:24 +0800 Subject: [PATCH 13/22] clean codes --- colossalai/kernel/triton/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py index 60f34175db73..08338e3c7fda 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/ops.py @@ -199,7 +199,7 @@ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Ten elif block_size >= 2048: BLOCK_M = 8 - softmax_kernel_2[grid](output_ptr = output, + softmax_kernel_2[grid](output_ptr = output, input_ptr = input, row_stride = input.stride(0), n_rows = num_rows, From d8dae228ab5b97063763d422c4145c6016f2e4a0 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 16:15:53 +0800 Subject: [PATCH 14/22] added tests --- colossalai/kernel/triton/ops.py | 12 ++---------- tests/test_kernels/test_self_attention.py | 2 -- tests/test_kernels/test_softmax.py | 3 ++- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py index 08338e3c7fda..a72487e9c13c 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/ops.py @@ -7,11 +7,6 @@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax_kernel import softmax_kernel -def self_attention_forward_with_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float): - - # TODO: call flash attention kernel implemetation (@cuiqing.li (tiandiao123) works on) - pass - def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels @@ -152,11 +147,8 @@ def self_attention_compute_using_triton(qkv, k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) - if use_flash: - data_output_triton = self_attention_forward_with_fusion(q, k, v, input_mask, scale) - else: - data_output_triton = self_attention_forward_without_fusion( - q, k, v, input_mask, scale) + 
data_output_triton = self_attention_forward_without_fusion( + q, k, v, input_mask, scale) return data_output_triton diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index 5151c0c14a6c..3230253a8bb8 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -118,8 +118,6 @@ def test_self_atttention_test(): assert check is True, "the triton ouput is not matched with torch output" - - if __name__ == "__main__": test_qkv_matmul() test_self_atttention_test() \ No newline at end of file diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index a47c1bb6228e..5efdbe09b39a 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -16,7 +16,8 @@ def test_softmax_op(): module = nn.Softmax(dim = -1) data_torch_out = module(data) data_triton_out = softmax(data) - print(torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3)) + check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) + assert check is True, "softmax outputs from triton and torch are not matched" if __name__ == "__main__": From c3c2e2b1ef3e8eaf1d9b50f5267f3e26207670df Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 16:20:08 +0800 Subject: [PATCH 15/22] update tests --- tests/test_kernels/test_self_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index 3230253a8bb8..e14d2ebaa20f 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -114,7 +114,7 @@ def test_self_atttention_test(): use_flash=False, triangular=True) - check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-3, atol=1e-2) + check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) assert check is True, "the triton ouput is not matched with torch output" From 2da38c6ca0e86f165759f5b024ca05782e669d45 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 17 Jul 2023 16:20:20 +0800 Subject: [PATCH 16/22] update tests --- colossalai/kernel/triton/qkv_matmul_kernel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py index f913928c3b0d..6789a4fa45ea 100644 --- a/colossalai/kernel/triton/qkv_matmul_kernel.py +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -2,7 +2,6 @@ import triton import triton.language as tl -from inference.ops.triton.k_activations import leaky_relu, relu ''' From 76743e3c07e30dddc93821983d0d3ac260fef4b0 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 18 Jul 2023 10:46:06 +0800 Subject: [PATCH 17/22] added attention --- tests/test_kernels/test_self_attention.py | 8 ++++++++ tests/test_kernels/test_softmax.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index e14d2ebaa20f..91bc4f9ae9df 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -8,6 +8,10 @@ def test_qkv_matmul(): + cuda_version = float(torch.version.cuda) + if cuda_version <= 11.4: + return + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) scale = 1.2 head_size = 32 @@ -66,6 +70,10 @@ def self_attention_compute_using_torch(qkv, scale, head_size ): + cuda_version = 
float(torch.version.cuda) + if cuda_version <= 11.4: + return + batches = qkv.shape[0] d_model = qkv.shape[-1] // 3 num_of_heads = d_model // head_size diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index 5efdbe09b39a..5fbbb66d50ea 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -5,6 +5,10 @@ from colossalai.kernel.triton.ops import softmax def test_softmax_op(): + cuda_version = float(torch.version.cuda) + if cuda_version <= 11.4: + return + device = "cuda" data_samples = [ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), From 585ab711c2079b176eba16878e157a0cb0875c19 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 18 Jul 2023 12:09:21 +0800 Subject: [PATCH 18/22] add --- colossalai/kernel/triton/ops.py | 385 +++++++++--------- colossalai/kernel/triton/qkv_matmul_kernel.py | 193 ++++----- colossalai/kernel/triton/softmax_kernel.py | 75 ++-- tests/test_kernels/test_softmax.py | 2 +- 4 files changed, 335 insertions(+), 320 deletions(-) diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py index a72487e9c13c..5e8d4ba3ec99 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/ops.py @@ -1,204 +1,209 @@ import torch from torch import nn -import triton -import triton.language as tl - -from .qkv_matmul_kernel import qkv_gemm_4d_kernel -from .softmax_kernel import softmax_kernel - - -def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): - r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels - Args: - q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) - input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) - scale: the float scale value which is used to multiply with Q*K^T before doing softmax - - Return: - output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) - """ - assert len(q.shape) == 4, "the shape of q val must be 4" - batches, M, H, K = q.shape - assert q.shape == k.shape, "the shape of q and the shape of k must be equal" - assert q.shape == v.shape, "the shape of q and the shape of v must be equal" - assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" - - N = k.shape[1] - - # head_size * num_of_head - d_model = q.shape[-1] * q.shape[-2] - - score_output = torch.empty( - (batches, H, M, N), device=q.device, dtype=q.dtype) - - grid = lambda meta: ( - batches, - H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * - triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) - - qkv_gemm_4d_kernel[grid]( - q, k, score_output, - M, N, K, - q.stride(0), q.stride(2), q.stride(1), q.stride(3), - k.stride(0), k.stride(2), k.stride(3), k.stride(1), - score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), - scale=scale, - # currently manually setting, later on we can use auto-tune config to match best setting - BLOCK_SIZE_M=64, - BLOCK_SIZE_N=32, - BLOCK_SIZE_K=32, - GROUP_SIZE_M=8, - ) - - softmax_output = torch.empty( - score_output.shape, device=score_output.device, dtype=score_output.dtype) - score_output_shape = score_output.shape - - 
score_output = score_output.view(-1, score_output.shape[-1]) - n_rows, n_cols = score_output.shape - - if n_rows <= 350000: +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + from .qkv_matmul_kernel import qkv_gemm_4d_kernel + from .softmax_kernel import softmax_kernel + + def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): + r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + Args: + q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) + scale: the float scale value which is used to multiply with Q*K^T before doing softmax + + Return: + output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) + """ + assert len(q.shape) == 4, "the shape of q val must be 4" + batches, M, H, K = q.shape + assert q.shape == k.shape, "the shape of q and the shape of k must be equal" + assert q.shape == v.shape, "the shape of q and the shape of v must be equal" + assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" + + N = k.shape[1] + + # head_size * num_of_head + d_model = q.shape[-1] * q.shape[-2] + + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + q, k, score_output, + M, N, K, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(3), k.stride(1), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) - block_size = max(triton.next_power_of_2(n_cols), 2) - num_warps = 4 - if block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - else: + softmax_output = torch.empty( + score_output.shape, device=score_output.device, dtype=score_output.dtype) + score_output_shape = score_output.shape + + score_output = score_output.view(-1, score_output.shape[-1]) + n_rows, n_cols = score_output.shape + + if n_rows <= 350000: + + block_size = max(triton.next_power_of_2(n_cols), 2) num_warps = 4 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + softmax_kernel[(n_rows, )]( + softmax_output, + score_output, + score_output.stride(0), + n_cols, + mask_ptr = input_mask, + num_warps=num_warps, + BLOCK_SIZE=block_size, + ) - softmax_kernel[(n_rows, )]( - softmax_output, - score_output, - score_output.stride(0), - n_cols, - mask_ptr = input_mask, - num_warps=num_warps, - BLOCK_SIZE=block_size, - ) + else: + #TODO: change softmax kernel functions to make it suitable for large size dimension + softmax_output = torch.nn.functional.softmax(score_output, dim=-1) + softmax_output = 
softmax_output.view(*score_output_shape) - else: - #TODO: change softmax kernel functions to make it suitable for large size dimension - softmax_output = torch.nn.functional.softmax(score_output, dim=-1) - softmax_output = softmax_output.view(*score_output_shape) - - batches, H, M, K = softmax_output.shape - N = v.shape[-1] - - output = torch.empty( - (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) - - grid = lambda meta: ( - batches, - H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * - triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) - - qkv_gemm_4d_kernel[grid]( - softmax_output, v, output, - M, N, K, - softmax_output.stride(0), - softmax_output.stride(1), - softmax_output.stride(2), - softmax_output.stride(3), - v.stride(0), - v.stride(2), - v.stride(1), - v.stride(3), - output.stride(0), - output.stride(2), - output.stride(1), - output.stride(3), - BLOCK_SIZE_M=128, - BLOCK_SIZE_N=64, - BLOCK_SIZE_K=64, - GROUP_SIZE_M=8, - scale=-1, - ) - return output.view(batches, -1, d_model) - - -def self_attention_compute_using_triton(qkv, - input_mask, - layer_past, - alibi, - scale, - head_size, - triangular=False, - use_flash=False): - - assert qkv.is_contiguous() - assert alibi is None, "current triton self-attention does not support alibi" - batches = qkv.shape[0] - d_model = qkv.shape[-1] // 3 - num_of_heads = d_model // head_size - - q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] - v = qkv[:, :, d_model * 2:] - q = q.view(batches, -1, num_of_heads, head_size) - k = k.view(batches, -1, num_of_heads, head_size) - v = v.view(batches, -1, num_of_heads, head_size) - - data_output_triton = self_attention_forward_without_fusion( - q, k, v, input_mask, scale) - - return data_output_triton - - -def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: - if mask is not None: - assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" - assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" - - hidden_dim = input.shape[-1] - output = torch.empty_like(input) - input = input.view(-1, hidden_dim) - if mask is not None: - mask = mask.view(-1, hidden_dim) - assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" - - num_rows, num_cols = input.shape - block_size = max(triton.next_power_of_2(num_cols), 2) - num_warps = 16 - if block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - else: - num_warps = 4 + batches, H, M, K = softmax_output.shape + N = v.shape[-1] - if num_rows <= 350000: - grid = (num_rows,) - softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) - else: - grid = lambda meta: () + output = torch.empty( + (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) grid = lambda meta: ( - triton.cdiv(num_rows, meta["BLOCK_M"]), + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + softmax_output, v, output, + M, N, K, + softmax_output.stride(0), + softmax_output.stride(1), + softmax_output.stride(2), + softmax_output.stride(3), + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, + scale=-1, ) + return output.view(batches, -1, d_model) + + + def 
self_attention_compute_using_triton(qkv, + input_mask, + layer_past, + alibi, + scale, + head_size, + triangular=False, + use_flash=False): + + assert qkv.is_contiguous() + assert alibi is None, "current triton self-attention does not support alibi" + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + data_output_triton = self_attention_forward_without_fusion( + q, k, v, input_mask, scale) + + return data_output_triton + - BLOCK_M = 32 + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + if mask is not None: + assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" + assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" + + hidden_dim = input.shape[-1] + output = torch.empty_like(input) + input = input.view(-1, hidden_dim) + if mask is not None: + mask = mask.view(-1, hidden_dim) + assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" + + num_rows, num_cols = input.shape + block_size = max(triton.next_power_of_2(num_cols), 2) + num_warps = 16 if block_size >= 4096: - BLOCK_M = 4 + num_warps = 16 elif block_size >= 2048: - BLOCK_M = 8 - - softmax_kernel_2[grid](output_ptr = output, - input_ptr = input, - row_stride = input.stride(0), - n_rows = num_rows, - n_cols = num_cols, - mask_ptr = mask, - # currently manually setting up size - BLOCK_M = 32, - BLOCK_SIZE = block_size) - - return output \ No newline at end of file + num_warps = 8 + else: + num_warps = 4 + + if num_rows <= 350000: + grid = (num_rows,) + softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + else: + grid = lambda meta: () + + grid = lambda meta: ( + triton.cdiv(num_rows, meta["BLOCK_M"]), + ) + + BLOCK_M = 32 + if block_size >= 4096: + BLOCK_M = 4 + elif block_size >= 2048: + BLOCK_M = 8 + + softmax_kernel_2[grid](output_ptr = output, + input_ptr = input, + row_stride = input.stride(0), + n_rows = num_rows, + n_cols = num_cols, + mask_ptr = mask, + # currently manually setting up size + BLOCK_M = 32, + BLOCK_SIZE = block_size) + + return output \ No newline at end of file diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py index 6789a4fa45ea..62fc6bba0360 100644 --- a/colossalai/kernel/triton/qkv_matmul_kernel.py +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -1,104 +1,109 @@ import torch -import triton -import triton.language as tl +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") +if HAS_TRITON: + ''' + this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + ''' + @triton.jit + def qkv_gemm_4d_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + # Meta-parameters + BLOCK_SIZE_M : tl.constexpr = 64, + BLOCK_SIZE_N : tl.constexpr = 32, + BLOCK_SIZE_K : tl.constexpr = 32, + GROUP_SIZE_M : 
tl.constexpr = 8, + ): + r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) + Args: + a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) + b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) + c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) + stride_ab(tl.constexpr): stride for bs-dimention for tensor array A + stride_ah(tl.constexpr): stride for h-dimention for tensor array A + stride_am(tl.constexpr): stride for m-dimention for tensor array A + stride_ak(tl.constexpr): stride for k-dimention for tensor array A + stride_bb(tl.constexpr): stride for bs-dimention for tensor array B + stride_bh(tl.constexpr): stride for h-dimention for tensor array B + stride_bk(tl.constexpr): stride for k-dimention for tensor array B + stride_bn(tl.constexpr): stride for n-dimention for tensor array B + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_ch(tl.constexpr): stride for h-dimention for tensor array output + stride_cm(tl.constexpr): stride for m-dimention for tensor array output + stride_cn(tl.constexpr): stride for n-dimention for tensor array output + BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a + BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b + BLOCK_SIZE_K : tiling size for K-dimension of a and b + GROUP_SIZE_M : group size for reducing cache miss, more details: + """ -''' -this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html -''' -@triton.jit -def qkv_gemm_4d_kernel( - a_ptr, - b_ptr, - c_ptr, - M, - N, - K, - stride_ab, - stride_ah, - stride_am, - stride_ak, - stride_bb, - stride_bh, - stride_bk, - stride_bn, - stride_cb, - stride_ch, - stride_cm, - stride_cn, - scale, - # Meta-parameters - BLOCK_SIZE_M : tl.constexpr = 64, - BLOCK_SIZE_N : tl.constexpr = 32, - BLOCK_SIZE_K : tl.constexpr = 32, - GROUP_SIZE_M : tl.constexpr = 8, -): - r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, - where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) - Args: - a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) - b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) - c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) - stride_ab(tl.constexpr): stride for bs-dimention for tensor array A - stride_ah(tl.constexpr): stride for h-dimention for tensor array A - stride_am(tl.constexpr): stride for m-dimention for tensor array A - stride_ak(tl.constexpr): stride for k-dimention for tensor array A - stride_bb(tl.constexpr): stride for bs-dimention for tensor array B - stride_bh(tl.constexpr): stride for h-dimention for tensor array B - stride_bk(tl.constexpr): stride for k-dimention for tensor array B - stride_bn(tl.constexpr): stride for n-dimention for tensor array B - stride_cb(tl.constexpr): stride for bs-dimention for tensor array output - stride_ch(tl.constexpr): stride for h-dimention for tensor array output - stride_cm(tl.constexpr): stride for m-dimention for tensor array output - stride_cn(tl.constexpr): stride for n-dimention for tensor array output - BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a - BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b - BLOCK_SIZE_K : tiling size for K-dimension of 
a and b - GROUP_SIZE_M : group size for reducing cache miss, more details: - """ + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + batch = tl.program_id(axis = 0) + head = tl.program_id(axis = 1) + pid = tl.program_id(axis = 2) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - batch = tl.program_id(axis = 0) - head = tl.program_id(axis = 1) - pid = tl.program_id(axis = 2) + # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m - # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + - (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) - b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + - (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, K, BLOCK_SIZE_K): - a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) - b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) - a = tl.load(a_ptrs, mask=a_mask, other=0.) - b = tl.load(b_ptrs, mask=b_mask, other=0.) - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - accumulator = accumulator.to(c_ptr.dtype.element_ty) - if scale > 0: - accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) + b = tl.load(b_ptrs, mask=b_mask, other=0.) 
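+            # accumulate the product of the current K tile in float32, then advance
+            # the A/B tile pointers by BLOCK_SIZE_K along the K dimension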
+            accumulator += tl.dot(a, b)
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+        accumulator = accumulator.to(c_ptr.dtype.element_ty)
+        if scale > 0:
+            accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)
+
-    offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
-              stride_cn * offs_accumu_n[None, :])
-    accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=accumulator_mask)
+        offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
+                  stride_cn * offs_accumu_n[None, :])
+        accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
+        tl.store(c_ptrs, accumulator, mask=accumulator_mask)
diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py
index 2c0f9d7b84d2..c215890badff 100644
--- a/colossalai/kernel/triton/softmax_kernel.py
+++ b/colossalai/kernel/triton/softmax_kernel.py
@@ -1,39 +1,44 @@
-import triton
-import triton.language as tl
-'''
-softmax kernel is modified based on
-https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
-'''
+try:
+    import triton
+    import triton.language as tl
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+if HAS_TRITON:
+    '''
+    softmax kernel is modified based on
+    https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
+    '''
+    @triton.jit
+    def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
+        r""" the kernel function for implementing softmax operator
+        Args:
+            output_ptr: the output after finishing softmax operation, (N, hidden_dim)
+            input_ptr: the tensor of input, shape should be (N, hidden_dim)
+            n_cols(tl.constexpr): the number of cols of input
+            BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
+        """
+        row_idx = tl.program_id(0)
+        row_start_ptr = input_ptr + row_idx * row_stride
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+        row_minus_max = row - tl.max(row, axis=0)
-@triton.jit
-def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
-    r""" the kernel function for implementing softmax operator
-    Args:
-        output_ptr: the output after finishing softmax operation, (N, hidden_dim)
-        input_ptr: the tensor of input, shape should be (N, hidden_dim)
-        n_cols(tl.constexpr): the number of cols of input
-        BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
-    """
-    row_idx = tl.program_id(0)
-    row_start_ptr = input_ptr + row_idx * row_stride
-    col_offsets = tl.arange(0, BLOCK_SIZE)
-    input_ptrs = row_start_ptr + col_offsets
-    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
-    row_minus_max = row - tl.max(row, axis=0)
+        if mask_ptr is not None:
+            # load mask into SRAM
+            mask_ptrs = (mask_ptr + (row_idx 
* row_stride)) + col_offsets + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) - if mask_ptr is not None: - # load mask into SRAM - mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets - mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + # update + row_minus_max = row_minus_max + mask - # update - row_minus_max = row_minus_max + mask - - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - output_row_start_ptr = output_ptr + row_idx * row_stride - output_ptrs = output_row_start_ptr + col_offsets - # Write back output to DRAM - tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) \ No newline at end of file + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * row_stride + output_ptrs = output_row_start_ptr + col_offsets + # Write back output to DRAM + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) \ No newline at end of file diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index 5fbbb66d50ea..e98829826db7 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -8,7 +8,7 @@ def test_softmax_op(): cuda_version = float(torch.version.cuda) if cuda_version <= 11.4: return - + device = "cuda" data_samples = [ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), From 9c80c4559982f267a90ad63b6dbe3dbadbef1031 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 18 Jul 2023 12:32:38 +0800 Subject: [PATCH 19/22] fixed pytest checking --- tests/test_kernels/test_self_attention.py | 25 ++++++++++++++--------- tests/test_kernels/test_softmax.py | 7 ++----- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index 91bc4f9ae9df..cf41c6b66138 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -1,4 +1,4 @@ -import triton +import pytest import torch from torch import nn import torch.nn.functional as F @@ -6,12 +6,18 @@ from colossalai.kernel.triton.ops import self_attention_compute_using_triton from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") -def test_qkv_matmul(): - cuda_version = float(torch.version.cuda) - if cuda_version <= 11.4: - return +@pytest.mark.skipif(float(torch.version.cuda) <= 11.4, + reason="triton requires cuda version to be higher than 11.4") +def test_qkv_matmul(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) scale = 1.2 head_size = 32 @@ -70,10 +76,7 @@ def self_attention_compute_using_torch(qkv, scale, head_size ): - cuda_version = float(torch.version.cuda) - if cuda_version <= 11.4: - return - + batches = qkv.shape[0] d_model = qkv.shape[-1] // 3 num_of_heads = d_model // head_size @@ -102,8 +105,10 @@ def self_attention_compute_using_torch(qkv, return res.view(batches, -1, d_model), score_output, softmax_output - +@pytest.mark.skipif(float(torch.version.cuda) <= 11.4, + reason="triton requires cuda version to be higher than 11.4") def test_self_atttention_test(): + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) 
data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( qkv.clone(), diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index e98829826db7..3944dba7525a 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -4,12 +4,9 @@ from colossalai.kernel.triton.ops import softmax +@pytest.mark.skipif(float(torch.version.cuda) <= 11.7, + reason="triton requires cuda version to be higher than 11.4") def test_softmax_op(): - cuda_version = float(torch.version.cuda) - if cuda_version <= 11.4: - return - - device = "cuda" data_samples = [ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), From d94a4535d787211c31e99539353f43b4b1bdf759 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 18 Jul 2023 12:33:49 +0800 Subject: [PATCH 20/22] add cuda check --- tests/test_kernels/test_softmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index 3944dba7525a..c4ef2798b7b0 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -4,7 +4,7 @@ from colossalai.kernel.triton.ops import softmax -@pytest.mark.skipif(float(torch.version.cuda) <= 11.7, +@pytest.mark.skipif(float(torch.version.cuda) <= 11.4, reason="triton requires cuda version to be higher than 11.4") def test_softmax_op(): data_samples = [ From bebac5e864f931a0db4a70b7de57ac805956ab40 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 18 Jul 2023 13:51:28 +0800 Subject: [PATCH 21/22] fix cuda version --- tests/test_kernels/test_self_attention.py | 12 ++++++------ tests/test_kernels/test_softmax.py | 6 ++++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index cf41c6b66138..af6a701dafab 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -1,4 +1,5 @@ -import pytest +import pytest +from packaging import version import torch from torch import nn import torch.nn.functional as F @@ -14,9 +15,9 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(float(torch.version.cuda) <= 11.4, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") def test_qkv_matmul(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) scale = 1.2 @@ -105,10 +106,9 @@ def self_attention_compute_using_torch(qkv, return res.view(batches, -1, d_model), score_output, softmax_output -@pytest.mark.skipif(float(torch.version.cuda) <= 11.4, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") def test_self_atttention_test(): - + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( qkv.clone(), diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py index c4ef2798b7b0..843d811d019c 100644 --- a/tests/test_kernels/test_softmax.py +++ b/tests/test_kernels/test_softmax.py @@ -1,11 +1,13 @@ 
import pytest +from packaging import version import torch from torch import nn from colossalai.kernel.triton.ops import softmax -@pytest.mark.skipif(float(torch.version.cuda) <= 11.4, - reason="triton requires cuda version to be higher than 11.4") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") def test_softmax_op(): data_samples = [ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), From 4aaacff62ba194aaed59cf17a2b5cb5ffc5ee994 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 18 Jul 2023 14:04:23 +0800 Subject: [PATCH 22/22] fix typo --- tests/test_kernels/test_self_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py index af6a701dafab..b316404a58db 100644 --- a/tests/test_kernels/test_self_attention.py +++ b/tests/test_kernels/test_self_attention.py @@ -128,7 +128,7 @@ def test_self_atttention_test(): triangular=True) check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) - assert check is True, "the triton ouput is not matched with torch output" + assert check is True, "the triton output is not matched with torch output" if __name__ == "__main__":
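For reference, a minimal sketch of how the softmax kernel added in this patch series might be launched directly from PyTorch. The wrapper name, grid layout, and BLOCK_SIZE heuristic below are illustrative assumptions rather than part of the patches; the repository's own entry point is colossalai.kernel.triton.ops.softmax, which the tests above import.

import torch
import triton

# assumes triton is installed, so softmax_kernel is defined under the HAS_TRITON guard
from colossalai.kernel.triton.softmax_kernel import softmax_kernel


def launch_softmax(x: torch.Tensor) -> torch.Tensor:
    # x: a contiguous (n_rows, n_cols) CUDA tensor; one Triton program handles one row
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    # BLOCK_SIZE must cover a full row (the docstring asks for BLOCK_SIZE >= hidden_dim),
    # so round n_cols up to the next power of two
    block_size = triton.next_power_of_2(n_cols)
    softmax_kernel[(n_rows,)](
        y,            # output_ptr
        x,            # input_ptr
        x.stride(0),  # row_stride
        n_cols,
        None,         # mask_ptr: no additive mask in this sketch
        BLOCK_SIZE=block_size,
    )
    return y


# usage sketch:
# out = launch_softmax(torch.randn(128, 512, device="cuda", dtype=torch.float32))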