From f597cc31992fc416f633d5dc1a1fd0dc3db71e5b Mon Sep 17 00:00:00 2001 From: Cuiqing Li Date: Tue, 18 Jul 2023 23:53:38 +0800 Subject: [PATCH 1/4] [Kernels] added triton-implemented of self attention for colossal-ai (#4241) * added softmax kernel * added qkv_kernel * added ops * adding tests * upload tets * fix tests * debugging * debugging tests * debugging * added * fixed errors * added softmax kernel * clean codes * added tests * update tests * update tests * added attention * add * fixed pytest checking * add cuda check * fix cuda version * fix typo --- colossalai/kernel/triton/ops.py | 209 ++++++++++++++++++ colossalai/kernel/triton/qkv_matmul_kernel.py | 109 +++++++++ colossalai/kernel/triton/softmax_kernel.py | 44 ++++ tests/test_kernels/test_self_attention.py | 136 ++++++++++++ tests/test_kernels/test_softmax.py | 27 +++ 5 files changed, 525 insertions(+) create mode 100644 colossalai/kernel/triton/ops.py create mode 100644 colossalai/kernel/triton/qkv_matmul_kernel.py create mode 100644 colossalai/kernel/triton/softmax_kernel.py create mode 100644 tests/test_kernels/test_self_attention.py create mode 100644 tests/test_kernels/test_softmax.py diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py new file mode 100644 index 000000000000..5e8d4ba3ec99 --- /dev/null +++ b/colossalai/kernel/triton/ops.py @@ -0,0 +1,209 @@ +import torch +from torch import nn + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + from .qkv_matmul_kernel import qkv_gemm_4d_kernel + from .softmax_kernel import softmax_kernel + + def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): + r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + Args: + q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) + input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) + scale: the float scale value which is used to multiply with Q*K^T before doing softmax + + Return: + output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) + """ + assert len(q.shape) == 4, "the shape of q val must be 4" + batches, M, H, K = q.shape + assert q.shape == k.shape, "the shape of q and the shape of k must be equal" + assert q.shape == v.shape, "the shape of q and the shape of v must be equal" + assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal" + + N = k.shape[1] + + # head_size * num_of_head + d_model = q.shape[-1] * q.shape[-2] + + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + q, k, score_output, + M, N, K, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(3), k.stride(1), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + softmax_output = torch.empty( + score_output.shape, device=score_output.device, dtype=score_output.dtype) + score_output_shape = score_output.shape + + score_output = score_output.view(-1, score_output.shape[-1]) + n_rows, n_cols = score_output.shape + + if n_rows <= 350000: + + block_size = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + softmax_kernel[(n_rows, )]( + softmax_output, + score_output, + score_output.stride(0), + n_cols, + mask_ptr = input_mask, + num_warps=num_warps, + BLOCK_SIZE=block_size, + ) + + else: + #TODO: change softmax kernel functions to make it suitable for large size dimension + softmax_output = torch.nn.functional.softmax(score_output, dim=-1) + softmax_output = softmax_output.view(*score_output_shape) + + batches, H, M, K = softmax_output.shape + N = v.shape[-1] + + output = torch.empty( + (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + softmax_output, v, output, + M, N, K, + softmax_output.stride(0), + softmax_output.stride(1), + softmax_output.stride(2), + softmax_output.stride(3), + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=64, + GROUP_SIZE_M=8, + scale=-1, + ) + return output.view(batches, -1, d_model) + + + def self_attention_compute_using_triton(qkv, + input_mask, + layer_past, + alibi, + scale, + head_size, + triangular=False, + use_flash=False): + + assert qkv.is_contiguous() + assert alibi is None, "current triton self-attention does not support alibi" + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + data_output_triton = self_attention_forward_without_fusion( + q, k, v, input_mask, scale) + + return data_output_triton + + + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + if mask is not None: + assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" + assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" + + hidden_dim = input.shape[-1] + output = torch.empty_like(input) + input = input.view(-1, hidden_dim) + if mask is not None: + mask = mask.view(-1, hidden_dim) + assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" + + num_rows, num_cols = input.shape + block_size = max(triton.next_power_of_2(num_cols), 2) + num_warps = 16 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + if num_rows <= 350000: + grid = (num_rows,) + softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + else: + grid = lambda meta: () + + grid = lambda meta: ( + triton.cdiv(num_rows, meta["BLOCK_M"]), + ) + + BLOCK_M = 32 + if block_size >= 4096: + BLOCK_M = 4 + elif block_size >= 2048: + BLOCK_M = 8 + + softmax_kernel_2[grid](output_ptr = output, + input_ptr = input, + row_stride = input.stride(0), + n_rows = num_rows, + n_cols = num_cols, + mask_ptr = mask, + # currently manually setting up size + BLOCK_M = 32, + BLOCK_SIZE = block_size) + + return output \ No newline at end of file diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py new file mode 100644 index 000000000000..62fc6bba0360 --- /dev/null +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -0,0 +1,109 @@ +import torch +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + + +if HAS_TRITON: + ''' + this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + ''' + @triton.jit + def qkv_gemm_4d_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + # Meta-parameters + BLOCK_SIZE_M : tl.constexpr = 64, + BLOCK_SIZE_N : tl.constexpr = 32, + BLOCK_SIZE_K : tl.constexpr = 32, + GROUP_SIZE_M : tl.constexpr = 8, + ): + r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) + Args: + a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) + b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) + c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) + stride_ab(tl.constexpr): stride for bs-dimention for tensor array A + stride_ah(tl.constexpr): stride for h-dimention for tensor array A + stride_am(tl.constexpr): stride for m-dimention for tensor array A + stride_ak(tl.constexpr): stride for k-dimention for tensor array A + stride_bb(tl.constexpr): stride for bs-dimention for tensor array B + stride_bh(tl.constexpr): stride for h-dimention for tensor array B + stride_bk(tl.constexpr): stride for k-dimention for tensor array B + stride_bn(tl.constexpr): stride for n-dimention for tensor array B + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_ch(tl.constexpr): stride for h-dimention for tensor array output + stride_cm(tl.constexpr): stride for m-dimention for tensor array output + stride_cn(tl.constexpr): stride for n-dimention for tensor array output + BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a + BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b + BLOCK_SIZE_K : tiling size for K-dimension of a and b + GROUP_SIZE_M : group size for reducing cache miss, more details: + """ + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + batch = tl.program_id(axis = 0) + head = tl.program_id(axis = 1) + pid = tl.program_id(axis = 2) + + # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) + b = tl.load(b_ptrs, mask=b_mask, other=0.) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + accumulator = accumulator.to(c_ptr.dtype.element_ty) + if scale > 0: + accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) + + + offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :]) + accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) + tl.store(c_ptrs, accumulator, mask=accumulator_mask) diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py new file mode 100644 index 000000000000..c215890badff --- /dev/null +++ b/colossalai/kernel/triton/softmax_kernel.py @@ -0,0 +1,44 @@ +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + ''' + softmax kernel is modified based on + https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py + ''' + @triton.jit + def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): + r""" the kernel function for implementing softmax operator + Args: + output_ptr: the output after finishing softmax operation, (N, hidden_dim) + input_ptr: the tensor of input, shape should be (N, hidden_dim) + n_cols(tl.constexpr): the number of cols of input + BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim + """ + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * row_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + + if mask_ptr is not None: + # load mask into SRAM + mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + + # update + row_minus_max = row_minus_max + mask + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * row_stride + output_ptrs = output_row_start_ptr + col_offsets + # Write back output to DRAM + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) \ No newline at end of file diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/test_self_attention.py new file mode 100644 index 000000000000..b316404a58db --- /dev/null +++ b/tests/test_kernels/test_self_attention.py @@ -0,0 +1,136 @@ +import pytest +from packaging import version +import torch +from torch import nn +import torch.nn.functional as F + +from colossalai.kernel.triton.ops import self_attention_compute_using_triton +from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +def test_qkv_matmul(): + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) + scale = 1.2 + head_size = 32 + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + q_copy = q.clone() + k_copy = k.clone() + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + k = torch.transpose(k, 2, 3).contiguous() + + torch_ouput = torch.einsum('bnij,bnjk->bnik', q, k) + torch_ouput *= 1.2 + + q, k = q_copy, k_copy + batches, M, H, K = q.shape + N = k.shape[1] + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + K = q.shape[3] + qkv_gemm_4d_kernel[grid]( + q, k, score_output, + M, N, K, + q.stride(0), q.stride(2), q.stride(1), q.stride(3), + k.stride(0), k.stride(2), k.stride(3), k.stride(1), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + ) + + check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) + assert check is True, "the outputs of triton and torch are not matched" + + +def self_attention_compute_using_torch(qkv, + input_mask, + scale, + head_size + ): + + batches = qkv.shape[0] + d_model = qkv.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = qkv[:, :, :d_model] + k = qkv[:, :, d_model:d_model * 2] + v = qkv[:, :, d_model * 2:] + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + v = v.view(batches, -1, num_of_heads, head_size) + + q = torch.transpose(q, 1, 2).contiguous() + k = torch.transpose(k, 1, 2).contiguous() + v = torch.transpose(v, 1, 2).contiguous() + + k = torch.transpose(k, -1, -2).contiguous() + + score_output = torch.einsum('bnij,bnjk->bnik', q, k) + score_output *= scale + + softmax_output = F.softmax(score_output, dim = -1) + res = torch.einsum('bnij,bnjk->bnik', softmax_output, v) + res = torch.transpose(res, 1, 2) + res = res.contiguous() + + + return res.view(batches, -1, d_model), score_output, softmax_output + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +def test_self_atttention_test(): + + qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) + data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( + qkv.clone(), + input_mask = None, + scale = 1.2, + head_size = 32 + ) + + data_output_triton = self_attention_compute_using_triton( + qkv.clone(), + alibi=None, + head_size=32, + scale=1.2, + input_mask=None, + layer_past=None, + use_flash=False, + triangular=True) + + check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) + assert check is True, "the triton output is not matched with torch output" + + +if __name__ == "__main__": + test_qkv_matmul() + test_self_atttention_test() \ No newline at end of file diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/test_softmax.py new file mode 100644 index 000000000000..843d811d019c --- /dev/null +++ b/tests/test_kernels/test_softmax.py @@ -0,0 +1,27 @@ +import pytest +from packaging import version +import torch +from torch import nn + +from colossalai.kernel.triton.ops import softmax + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +def test_softmax_op(): + data_samples = [ + torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), + torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), + torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16) + ] + + for data in data_samples: + module = nn.Softmax(dim = -1) + data_torch_out = module(data) + data_triton_out = softmax(data) + check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) + assert check is True, "softmax outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_softmax_op() \ No newline at end of file From 35748ac64095e231f50d85c24532fa2b58f16439 Mon Sep 17 00:00:00 2001 From: Cuiqing Li Date: Tue, 18 Jul 2023 23:53:38 +0800 Subject: [PATCH 2/4] add bloom attention policy --- .../models/bloom/triton_attention_forward.py | 143 ++++++++++++++++ .../coati/trainer/strategies/tp/policy.py | 11 +- colossalai/kernel/triton/ops.py | 155 +++++++++++++++++- colossalai/kernel/triton/qkv_matmul_kernel.py | 108 ++++++++++++ 4 files changed, 415 insertions(+), 2 deletions(-) create mode 100644 applications/Chat/coati/models/bloom/triton_attention_forward.py diff --git a/applications/Chat/coati/models/bloom/triton_attention_forward.py b/applications/Chat/coati/models/bloom/triton_attention_forward.py new file mode 100644 index 000000000000..ec5c7958dcca --- /dev/null +++ b/applications/Chat/coati/models/bloom/triton_attention_forward.py @@ -0,0 +1,143 @@ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from torch.nn import functional as F +import torch.distributed as dist +from transformers.models.bloom.configuration_bloom import BloomConfig + +from inference.ops.self_attention import compute_attention_for_bloom + +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: + """ + Dropout add function + + Args: + x (`torch.tensor`, *required*): + input tensor + residual (`torch.tensor`, *required*): + esidual tensor + prob (`float`, *required*): + dropout probability + training (`bool`, *required*): + training mode + """ + out = F.dropout(x, p=prob, training=training) + out = residual + out + return out + +class TritonBloomAttention(nn.Module): + def __init__(self, config: BloomConfig): + super().__init__() + + self.pretraining_tp = config.pretraining_tp + self.slow_but_exact = config.slow_but_exact + + self.hidden_size = config.hidden_size + self.num_heads = config.n_head + self.head_dim = self.hidden_size // self.num_heads + self.split_size = self.hidden_size + self.hidden_dropout = config.hidden_dropout + + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." + ) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = 1.0 + + self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` + + Args: + fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + q_length = query_layer.shape[1] + batch_size = query_layer.shape[0] + num_heads = query_layer.shape[2] + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape + alibi = alibi.view(batch_size, num_heads, q_length, -1) + + context_layer = compute_attention_for_bloom(q = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim), + k = key_layer.view(batch_size, self.num_heads, self.head_dim, kv_length), + v = value_layer.view(batch_size, self.num_heads, kv_length, self.head_dim), + alibi = alibi, + beta = self.beta, + scale = self.inv_norm_factor, + attention_mask = attention_mask, + drop_out = self.hidden_dropout, + head_mask = head_mask, + layer_past = layer_past, + use_cache = True, + ) + + if use_cache: + present = (key_layer, value_layer) + else: + present = None + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, present) + if output_attentions: + outputs += (attention_probs,) + + return outputs diff --git a/applications/Chat/coati/trainer/strategies/tp/policy.py b/applications/Chat/coati/trainer/strategies/tp/policy.py index 0a5b8d44c26b..2049edddef1b 100644 --- a/applications/Chat/coati/trainer/strategies/tp/policy.py +++ b/applications/Chat/coati/trainer/strategies/tp/policy.py @@ -3,11 +3,14 @@ from typing import Dict, Type import torch.nn as nn -from coati.models.lora import LoraLinear +from torch.nn import functional as F from torch.nn import Module +from transformers.models.bloom.configuration_bloom import BloomConfig from transformers.models.bloom.modeling_bloom import BloomAttention, BloomForCausalLM, BloomMLP from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer +from coati.models.lora import LoraLinear +from coati.modles.bloom.triton_attention_forward import TritonBloomAttention from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.lazy import LazyTensor @@ -132,6 +135,12 @@ def replace(self, module: Module) -> bool: module.forward = MethodType(bloom_attn_fwd, module) return False +class BloomAttentionTritonPolicy(Policy): + + def replace(self, module: Module) -> bool: + assert isinstance(module, BloomAttention) + module.forward = MethodType(TritonBloomAttention.forward, module) + return False class BloomMLPPolicy(Policy): diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py index 5e8d4ba3ec99..4927ef03e286 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/ops.py @@ -10,7 +10,7 @@ print("please install triton from https://github.com/openai/triton") if HAS_TRITON: - from .qkv_matmul_kernel import qkv_gemm_4d_kernel + from .qkv_matmul_kernel import qkv_gemm_4d_kernel, qkv_gemm_4d_kernel_alibi from .softmax_kernel import softmax_kernel def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): @@ -128,6 +128,159 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t scale=-1, ) return output.view(batches, -1, d_model) + + def compute_attention_for_bloom(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alibi: torch.Tensor, + beta: torch.float32 = 1, + scale: torch.float32 = 1.2, + attention_mask: torch.Tensor = None, + drop_out: torch.float32 = -1, + head_mask: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = True + ): + r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels to used for bloom attention + Args: + q (torch.Tensor): Q embedding in attention layer, shape should be (batch, num_heads, seq_len, head_size) + k (torch.Tensor): K embedding in attention layer, shape should be (batch, num_heads, head_size, kv_length) + v (torch.Tensor): V embedding in attention layer, shape should be (batch, num_heads, kv_length, head_size) + alibi(torch.Tensor): bias for qk^T GEMM, shape should be (batch, H, q_length, kv_length) + input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len) + scale: the float scale value which is used to multiply with Q*K^T before doing softmax + beta: the float value for alibi bias matrix + + Return: + output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size) + """ + + assert len(q.shape) == len(k.shape), "the dimensions must be matched" + assert len(q.shape) == len(v.shape), "the dimensions must be matched" + assert len(q.shape) == 4, "the length of input q must be 4, which is (batch, seq_len, num_heads, head_dim)" + + batches, H, M, K = q.shape + d_model = q.shape[1] * q.shape[3] + + # k shape: (batches, num_heads, head_dim (K), seq_k(N)) + N = k.shape[-1] + + score_output = torch.empty( + (batches, H, M, N), device=q.device, dtype=q.dtype) + + assert len(score_output) == len(alibi), "the length of alibi and score output should be matched" + assert score_output.shape[:-1] == alibi.shape[:-1], "the shape of alibi and score outout also should be the same" + alibi = alibi.expand(batches, H, M, N) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel_alibi[grid]( + q, k, alibi, score_output, + M, N, K, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + scale=scale, + beta=beta, + # currently manually setting, later on we can use auto-tune config to match best setting + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + num_stages=4, + ) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] + input_dtype = score_output.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + score_output = score_output.to(torch.float) + + if attention_mask is not None: + score_output = torch.masked_fill(score_output, attention_mask, torch.finfo(score_output.dtype).min) + + softmax_leading_size = batches * H * M + + if softmax_leading_size <= 350000: + + score_output = score_output.to(input_dtype) + softmax_output = torch.empty( + score_output.shape, device=score_output.device, dtype=score_output.dtype) + score_output_shape = score_output.shape + + score_output = score_output.view(-1, score_output.shape[-1]) + n_rows, n_cols = score_output.shape + + block_size = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + softmax_kernel[(n_rows, )]( + softmax_output, + score_output, + score_output.stride(0), + n_cols, + mask_ptr = None, + num_warps=num_warps, + BLOCK_SIZE=block_size, + ) + else: + # TODO: fix softmax layer to make kernel to be suitable for large size cases + softmax_output = F.softmax(score_output, dim=-1, dtype=torch.float32).to(input_dtype) + + if drop_out > 0 and drop_out < 1: + softmax_output = F.dropout(softmax_output, drop_out, False, True).to(input_dtype) + + if head_mask is not None: + softmax_output = softmax_output * head_mask + softmax_output = softmax_output.to(input_dtype) + + batches, H, M, K = softmax_output.shape + N = v.shape[-1] + + output = torch.empty( + (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) + + grid = lambda meta: ( + batches, + H, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * + triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + qkv_gemm_4d_kernel[grid]( + softmax_output, v, output, + M, N, K, + softmax_output.stride(0), + softmax_output.stride(1), + softmax_output.stride(2), + softmax_output.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=32, + BLOCK_SIZE_K=32, + GROUP_SIZE_M=8, + scale=-1, + ) + + return output.view(batches, -1 , d_model) def self_attention_compute_using_triton(qkv, diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py index 62fc6bba0360..191397d9bc8d 100644 --- a/colossalai/kernel/triton/qkv_matmul_kernel.py +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -107,3 +107,111 @@ def qkv_gemm_4d_kernel( stride_cn * offs_accumu_n[None, :]) accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) tl.store(c_ptrs, accumulator, mask=accumulator_mask) + + + @triton.jit + def qkv_gemm_4d_kernel_alibi( + a_ptr, + b_ptr, + alibi_ptr, + c_ptr, + M, + N, + K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + beta, + # Meta-parameters + BLOCK_SIZE_M : tl.constexpr = 64, + BLOCK_SIZE_N : tl.constexpr = 32, + BLOCK_SIZE_K : tl.constexpr = 32, + GROUP_SIZE_M : tl.constexpr = 8, + ): + r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) + Args: + a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) + b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K) + c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N) + stride_ab(tl.constexpr): stride for bs-dimention for tensor array A + stride_ah(tl.constexpr): stride for h-dimention for tensor array A + stride_am(tl.constexpr): stride for m-dimention for tensor array A + stride_ak(tl.constexpr): stride for k-dimention for tensor array A + stride_bb(tl.constexpr): stride for bs-dimention for tensor array B + stride_bh(tl.constexpr): stride for h-dimention for tensor array B + stride_bk(tl.constexpr): stride for k-dimention for tensor array B + stride_bn(tl.constexpr): stride for n-dimention for tensor array B + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_ch(tl.constexpr): stride for h-dimention for tensor array output + stride_cm(tl.constexpr): stride for m-dimention for tensor array output + stride_cn(tl.constexpr): stride for n-dimention for tensor array output + BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a + BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b + BLOCK_SIZE_K : tiling size for K-dimension of a and b + GROUP_SIZE_M : group size for reducing cache miss, more details: + """ + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + batch = tl.program_id(axis = 0) + head = tl.program_id(axis = 1) + pid = tl.program_id(axis = 2) + + # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) + b = tl.load(b_ptrs, mask=b_mask, other=0.) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + accumulator = accumulator.to(c_ptr.dtype.element_ty) + if scale > 0: + accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) + + + offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :]) + accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) + + # load alibi + alibi_ptrs = (alibi_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :]) + alibi_vals = tl.load(alibi_ptrs, mask=accumulator_mask, other=0.) + + accumulator += (alibi_vals * beta).to(c_ptr.dtype.element_ty) + accumulator = accumulator.to(c_ptr.dtype.element_ty) + + tl.store(c_ptrs, accumulator, mask=accumulator_mask) From 254053e912e1c80c471ff1d62097fbb023654279 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Fri, 28 Jul 2023 12:53:05 +0800 Subject: [PATCH 3/4] fix path --- .../Chat/coati/models/bloom/triton_attention_forward.py | 2 +- colossalai/kernel/triton/ops.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/applications/Chat/coati/models/bloom/triton_attention_forward.py b/applications/Chat/coati/models/bloom/triton_attention_forward.py index ec5c7958dcca..e08d0eba14cb 100644 --- a/applications/Chat/coati/models/bloom/triton_attention_forward.py +++ b/applications/Chat/coati/models/bloom/triton_attention_forward.py @@ -7,7 +7,7 @@ import torch.distributed as dist from transformers.models.bloom.configuration_bloom import BloomConfig -from inference.ops.self_attention import compute_attention_for_bloom +from colossalai.kernel.triton.ops import compute_attention_for_bloom def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: """ diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py index 4927ef03e286..ea6f5cb52e59 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/ops.py @@ -1,3 +1,5 @@ +from typing import Optional, Tuple + import torch from torch import nn From c6de4de64a7c7ac2c27575502209b2b52557390d Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Fri, 28 Jul 2023 14:25:56 +0800 Subject: [PATCH 4/4] change bloom attention --- .../models/bloom/triton_attention_forward.py | 44 ++----------------- 1 file changed, 3 insertions(+), 41 deletions(-) diff --git a/applications/Chat/coati/models/bloom/triton_attention_forward.py b/applications/Chat/coati/models/bloom/triton_attention_forward.py index e08d0eba14cb..caecba978731 100644 --- a/applications/Chat/coati/models/bloom/triton_attention_forward.py +++ b/applications/Chat/coati/models/bloom/triton_attention_forward.py @@ -6,6 +6,7 @@ from torch.nn import functional as F import torch.distributed as dist from transformers.models.bloom.configuration_bloom import BloomConfig +from transformers.models.bloom.model_bloom import BloomAttention from colossalai.kernel.triton.ops import compute_attention_for_bloom @@ -27,48 +28,9 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: out = residual + out return out -class TritonBloomAttention(nn.Module): +class TritonBloomAttention(BloomAttention): def __init__(self, config: BloomConfig): - super().__init__() - - self.pretraining_tp = config.pretraining_tp - self.slow_but_exact = config.slow_but_exact - - self.hidden_size = config.hidden_size - self.num_heads = config.n_head - self.head_dim = self.hidden_size // self.num_heads - self.split_size = self.hidden_size - self.hidden_dropout = config.hidden_dropout - - if self.head_dim * self.num_heads != self.hidden_size: - raise ValueError( - f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" - f" {self.num_heads})." - ) - - # Layer-wise attention scaling - self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) - self.beta = 1.0 - - self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) - self.dense = nn.Linear(self.hidden_size, self.hidden_size) - self.attention_dropout = nn.Dropout(config.attention_dropout) - - def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory - storage as `fused_qkv` - - Args: - fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] - - Returns: - query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] - value: [batch_size, seq_length, num_heads, head_dim] - """ - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + super(TritonBloomAttention, self).__init__(config) def forward( self,