Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions colossalai/kernel/triton/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import torch
from torch import nn

try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax_kernel import softmax_kernel

def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
v (torch.Tensor): V embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
input_mask (torch.Tensor): mask for softmax layer, shape should be (batch, num_heads, seq_lem, seq_len)
scale: the float scale value which is used to multiply with Q*K^T before doing softmax

Return:
output (Torch.Tensor): The output shape is (batch, seq_len, num_heads, head_size)
"""
assert len(q.shape) == 4, "the shape of q val must be 4"
batches, M, H, K = q.shape
assert q.shape == k.shape, "the shape of q and the shape of k must be equal"
assert q.shape == v.shape, "the shape of q and the shape of v must be equal"
assert q.shape[-1] == k.shape[-1], "the last dimension of q and k must be equal"

N = k.shape[1]

# head_size * num_of_head
d_model = q.shape[-1] * q.shape[-2]

score_output = torch.empty(
(batches, H, M, N), device=q.device, dtype=q.dtype)

grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)

qkv_gemm_4d_kernel[grid](
q, k, score_output,
M, N, K,
q.stride(0), q.stride(2), q.stride(1), q.stride(3),
k.stride(0), k.stride(2), k.stride(3), k.stride(1),
score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)

softmax_output = torch.empty(
score_output.shape, device=score_output.device, dtype=score_output.dtype)
score_output_shape = score_output.shape

score_output = score_output.view(-1, score_output.shape[-1])
n_rows, n_cols = score_output.shape

if n_rows <= 350000:

block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
else:
num_warps = 4

softmax_kernel[(n_rows, )](
softmax_output,
score_output,
score_output.stride(0),
n_cols,
mask_ptr = input_mask,
num_warps=num_warps,
BLOCK_SIZE=block_size,
)

else:
#TODO: change softmax kernel functions to make it suitable for large size dimension
softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
softmax_output = softmax_output.view(*score_output_shape)

batches, H, M, K = softmax_output.shape
N = v.shape[-1]

output = torch.empty(
(batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)

grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)

qkv_gemm_4d_kernel[grid](
softmax_output, v, output,
M, N, K,
softmax_output.stride(0),
softmax_output.stride(1),
softmax_output.stride(2),
softmax_output.stride(3),
v.stride(0),
v.stride(2),
v.stride(1),
v.stride(3),
output.stride(0),
output.stride(2),
output.stride(1),
output.stride(3),
BLOCK_SIZE_M=128,
BLOCK_SIZE_N=64,
BLOCK_SIZE_K=64,
GROUP_SIZE_M=8,
scale=-1,
)
return output.view(batches, -1, d_model)


def self_attention_compute_using_triton(qkv,
input_mask,
layer_past,
alibi,
scale,
head_size,
triangular=False,
use_flash=False):

assert qkv.is_contiguous()
assert alibi is None, "current triton self-attention does not support alibi"
batches = qkv.shape[0]
d_model = qkv.shape[-1] // 3
num_of_heads = d_model // head_size

q = qkv[:, :, :d_model]
k = qkv[:, :, d_model:d_model * 2]
v = qkv[:, :, d_model * 2:]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)

data_output_triton = self_attention_forward_without_fusion(
q, k, v, input_mask, scale)

return data_output_triton


def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
if mask is not None:
assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask"
assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention"

hidden_dim = input.shape[-1]
output = torch.empty_like(input)
input = input.view(-1, hidden_dim)
if mask is not None:
mask = mask.view(-1, hidden_dim)
assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same"

num_rows, num_cols = input.shape
block_size = max(triton.next_power_of_2(num_cols), 2)
num_warps = 16
if block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
else:
num_warps = 4

if num_rows <= 350000:
grid = (num_rows,)
softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps)
else:
grid = lambda meta: ()

grid = lambda meta: (
triton.cdiv(num_rows, meta["BLOCK_M"]),
)

BLOCK_M = 32
if block_size >= 4096:
BLOCK_M = 4
elif block_size >= 2048:
BLOCK_M = 8

softmax_kernel_2[grid](output_ptr = output,
input_ptr = input,
row_stride = input.stride(0),
n_rows = num_rows,
n_cols = num_cols,
mask_ptr = mask,
# currently manually setting up size
BLOCK_M = 32,
BLOCK_SIZE = block_size)

return output
109 changes: 109 additions & 0 deletions colossalai/kernel/triton/qkv_matmul_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")


if HAS_TRITON:
'''
this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
'''
@triton.jit
def qkv_gemm_4d_kernel(
a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_ab,
stride_ah,
stride_am,
stride_ak,
stride_bb,
stride_bh,
stride_bk,
stride_bn,
stride_cb,
stride_ch,
stride_cm,
stride_cn,
scale,
# Meta-parameters
BLOCK_SIZE_M : tl.constexpr = 64,
BLOCK_SIZE_N : tl.constexpr = 32,
BLOCK_SIZE_K : tl.constexpr = 32,
GROUP_SIZE_M : tl.constexpr = 8,
):
r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer,
where score_matrix is softmax(Q*V^T/sqrt(hidden_size))
Args:
a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K)
b_ptr(torch.Tensor): pointer to input tensor array (bs, N, h, K) or (bs, h, N, K)
c_ptr(torch.Tensor): pointer to output tensor array (bs, M, h, N) or (bs, h, M, N)
stride_ab(tl.constexpr): stride for bs-dimention for tensor array A
stride_ah(tl.constexpr): stride for h-dimention for tensor array A
stride_am(tl.constexpr): stride for m-dimention for tensor array A
stride_ak(tl.constexpr): stride for k-dimention for tensor array A
stride_bb(tl.constexpr): stride for bs-dimention for tensor array B
stride_bh(tl.constexpr): stride for h-dimention for tensor array B
stride_bk(tl.constexpr): stride for k-dimention for tensor array B
stride_bn(tl.constexpr): stride for n-dimention for tensor array B
stride_cb(tl.constexpr): stride for bs-dimention for tensor array output
stride_ch(tl.constexpr): stride for h-dimention for tensor array output
stride_cm(tl.constexpr): stride for m-dimention for tensor array output
stride_cn(tl.constexpr): stride for n-dimention for tensor array output
BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a
BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b
BLOCK_SIZE_K : tiling size for K-dimension of a and b
GROUP_SIZE_M : group size for reducing cache miss, more details:
"""

num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
batch = tl.program_id(axis = 0)
head = tl.program_id(axis = 1)
pid = tl.program_id(axis = 2)

# the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m


offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
(offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.)
b = tl.load(b_ptrs, mask=b_mask, other=0.)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

accumulator = accumulator.to(c_ptr.dtype.element_ty)
if scale > 0:
accumulator = accumulator * scale.to(c_ptr.dtype.element_ty)


offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] +
stride_cn * offs_accumu_n[None, :])
accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N)
tl.store(c_ptrs, accumulator, mask=accumulator_mask)
44 changes: 44 additions & 0 deletions colossalai/kernel/triton/softmax_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
'''
softmax kernel is modified based on
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
'''
@triton.jit
def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):
r""" the kernel function for implementing softmax operator
Args:
output_ptr: the output after finishing softmax operation, (N, hidden_dim)
input_ptr: the tensor of input, shape should be (N, hidden_dim)
n_cols(tl.constexpr): the number of cols of input
BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim
"""
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)

if mask_ptr is not None:
# load mask into SRAM
mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)

# update
row_minus_max = row_minus_max + mask

numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * row_stride
output_ptrs = output_row_start_ptr + col_offsets
# Write back output to DRAM
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
Loading