diff --git a/colossalai/gptq/__init__.py b/colossalai/gptq/__init__.py
new file mode 100644
index 000000000000..0e0ee5152138
--- /dev/null
+++ b/colossalai/gptq/__init__.py
@@ -0,0 +1,7 @@
+from .cai_gptq import HAS_AUTO_GPTQ
+
+if HAS_AUTO_GPTQ:
+    from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear,
+                           CaiQuantLinear, CaiGPTQLinearOp)
+
+
diff --git a/colossalai/gptq/cai_gptq/__init__.py b/colossalai/gptq/cai_gptq/__init__.py
new file mode 100644
index 000000000000..68addb8fb2f5
--- /dev/null
+++ b/colossalai/gptq/cai_gptq/__init__.py
@@ -0,0 +1,14 @@
+import warnings
+
+HAS_AUTO_GPTQ = False
+try:
+    import auto_gptq
+    HAS_AUTO_GPTQ = True
+except ImportError:
+    warnings.warn('Please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
+    HAS_AUTO_GPTQ = False
+
+if HAS_AUTO_GPTQ:
+    from .gptq_triton import gptq_fused_linear_triton
+    from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear
+    from .gptq_op import CaiGPTQLinearOp
diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py
new file mode 100644
index 000000000000..737b24462dc4
--- /dev/null
+++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -0,0 +1,131 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .gptq_op import CaiGPTQLinearOp
+
+
+class CaiQuantLinear(nn.Module):
+
+    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+        super().__init__()
+        if bits not in [2, 4, 8]:
+            raise NotImplementedError("Only 2, 4 and 8 bits are supported.")
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.bits = bits
+        self.maxq = 2**self.bits - 1
+        self.groupsize = groupsize if groupsize != -1 else infeatures
+
+        # Quantized weights and zero points are bit-packed into int64 words;
+        # each word holds (64 // bits) values.
+        self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64))
+        self.register_buffer(
+            'qzeros',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64))
+        self.register_buffer('scales',
+                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+        self.register_buffer('g_idx',
+                             torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+
+        if bias:
+            self.register_buffer('bias', torch.zeros(outfeatures, dtype=torch.float16))
+        else:
+            self.bias = None
+
+        self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
+
+    def pack(self, linear, scales, zeros, g_idx=None):
+        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
+            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        half_scales = scales.clone().half()
+        self.scales = scales.clone().half()
+        if linear.bias is not None:
+            self.bias = linear.bias.clone().half()
+
+        # Pack into 64-bit words (e.g. 16 values per word at 4 bits); the
+        # commented alternative below packs into 32-bit words instead.
+        pbits = 64
+        ptype = torch.int64
+        unsign_type = np.uint64
+        sign_type = np.int64
+
+        # pbits = 32
+        # ptype = torch.int32
+        # unsign_type = np.uint32
+        # sign_type = np.int32
+
+        # Quantize each input channel to integers: q = round(w / scale) + zero.
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(
+                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) /
+                            half_scales[g_idx[idx]]).to(ptype)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(unsign_type)
+        qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)
+
+        # Bit-pack every (pbits // bits) consecutive integer weights along the
+        # input dimension into one word; value j lands at bit offset bits * j.
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (pbits // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += pbits // self.bits
+                row += 1
+            else:
+                raise NotImplementedError("Only 2, 4 and 8 bits are supported.")
+        qweight = qweight.astype(sign_type)
+        self.qweight.data.copy_(torch.from_numpy(qweight).contiguous())
+
+        # Zero points are stored minus one and bit-packed the same way, but
+        # along the output dimension.
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
+        zeros -= 1
+        zeros = zeros.numpy().astype(unsign_type)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (pbits // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += pbits // self.bits
+                col += 1
+            else:
+                raise NotImplementedError("Only 2, 4 and 8 bits are supported.")
+        qzeros = qzeros.astype(sign_type)
+        self.qzeros.data.copy_(torch.from_numpy(qzeros))
+
+        # A trivial (sequential) g_idx is signalled with None so that forward()
+        # can take the faster kernel path that recomputes group indices
+        # instead of gathering them from memory.
+        if torch.equal(self.g_idx, g_idx):
+            self.g_idx = None
+        else:
+            self.g_idx = g_idx
+
+    def forward(self, x):
+        cai_out = self.gptq_linear(x,
+                                   self.qweight,
+                                   self.scales,
+                                   self.qzeros,
+                                   g_idx=self.g_idx,
+                                   bias=self.bias)
+        return cai_out
+
+
+def make_cai_quant_linear(module, names, bits, groupsize, name=''):
+    if isinstance(module, CaiQuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+            delattr(module, attr)
+            setattr(module, attr,
+                    CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
+    for name1, child in module.named_children():
+        make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
diff --git a/colossalai/gptq/cai_gptq/gptq_op.py b/colossalai/gptq/cai_gptq/gptq_op.py
new file mode 100644
index 000000000000..aca1cb5b87c5
--- /dev/null
+++ b/colossalai/gptq/cai_gptq/gptq_op.py
@@ -0,0 +1,44 @@
+import torch
+
+from .gptq_triton import gptq_fused_linear_triton
+
+
+class CaiGPTQLinearOp(torch.nn.Module):
+
+    def __init__(self, gptq_group_size, gptq_quant_bits):
+        super(CaiGPTQLinearOp, self).__init__()
+        self.group_size = gptq_group_size
+        self.bits = gptq_quant_bits
+        self.maxq = 2**self.bits - 1
+        # Placeholder passed to the kernel when bias/residual are absent.
+        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())
+
+    def forward(self,
+                input: torch.Tensor,
+                weight: torch.Tensor,
+                weight_scales: torch.Tensor,
+                weight_zeros: torch.Tensor,
+                g_idx: torch.Tensor = None,
+                act_type=0,
+                bias: torch.Tensor = None,
+                residual: torch.Tensor = None,
+                qkv_fused=False):
+        add_bias = True
+        if bias is None:
+            bias = self.empty_tensor
+            add_bias = False
+
+        add_residual = True
+        if residual is None:
+            residual = self.empty_tensor
+            add_residual = False
+        x = input.view(-1, input.shape[-1])
+
+        out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual, self.bits, self.maxq,
+                                       self.group_size, qkv_fused, add_bias, add_residual,
+                                       act_type=act_type, g_idx=g_idx)
+        if qkv_fused:
+            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
+        else:
+            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])
+
+        return out
\ No newline at end of file
diff --git a/colossalai/gptq/cai_gptq/gptq_triton.py b/colossalai/gptq/cai_gptq/gptq_triton.py
new file mode 100644
index 000000000000..8a505ebad73f
--- /dev/null
+++ b/colossalai/gptq/cai_gptq/gptq_triton.py
@@ -0,0 +1,467 @@
+import math
+
+import torch
+import triton
+import triton.language as tl
+from auto_gptq.nn_modules.triton_utils import custom_autotune
+
+# The kernels below are based on https://github.com/fpgaminer/GPTQ-triton.
+
+
+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def cosh(x):
+    exp_x = tl.exp(x)
+    return (exp_x + 1.0 / exp_x) * 0.5
+
+
+# Triton implementations of the most commonly used activations.
+# See for instance http://arxiv.org/abs/1606.08415 for an overview.
+
+
+@triton.jit
+def relu(x):
+    """
+    ReLU_ activation function
+
+    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
+    """
+    return tl.where(x >= 0, x, 0.0)
+
+
+@triton.jit
+def squared_relu(x):
+    """
+    Squared ReLU activation, as proposed in the Primer_ paper.
+
+    .. _Primer: https://arxiv.org/abs/2109.08668
+    """
+    x_sq = x * x
+    return tl.where(x > 0.0, x_sq, 0.0)
+
+
+@triton.jit
+def star_relu(x):
+    """
+    Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper.
+
+    .. _"MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf
+    """
+    x_sq = x * x
+    return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472
+
+
+@triton.jit
+def leaky_relu(x):
+    """
+    LeakyReLU_ activation
+
+    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
+    """
+    return tl.where(x >= 0.0, x, 0.01 * x)
+
+
+# sqrt(2 / pi), used by the tanh approximation of GeLU below; Triton resolves
+# this module-level scalar as a compile-time constant inside the jitted kernels.
+_kAlpha = math.sqrt(2.0 / math.pi)
+
+
+@triton.jit
+def gelu(x):
+    """
+    GeLU_ activation - Gaussian error linear unit
+
+    .. 
_GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x))) + + +@triton.jit +def smelu(x): + """ + SmeLU_ activation - Smooth ReLU with beta=2.0 + + .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf + """ + beta = 2.0 + + relu = tl.where(x >= beta, x, 0.0) + return tl.where( + tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) + + +@triton.jit +def silu(x): + return x*tl.sigmoid(x) + + +@custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, +) +@triton.jit +def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, + M, N, K, bits, maxq, gptq_group_size, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, + QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. 
+    A is of shape (M, K) float16
+    B is packed int64 of shape (K // (64 // bits), N)
+    C is of shape (M, N) float16
+    scales is of shape (G, N) float16
+    zeros is packed int64 of shape (G, N // (64 // bits))
+    """
+    infeature_per_bits = 64 // bits
+
+    pid = tl.program_id(axis=0)
+    NK = K
+
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)
+    # For fused QKV, three weight matrices are laid out back to back and the
+    # launch grid is three times larger; qkv_offset selects the current matrix.
+    qkv_offset = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)    # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+
+    a_mask = (offs_am[:, None] < M)
+    # b_ptrs repeats elements along the K axis (64 // bits) times, since each
+    # packed int64 word of B covers that many rows of K.
+    b_ptrs = b_ptr + qkv_offset * N * NK // infeature_per_bits + (
+        (offs_k[:, None] // infeature_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)    # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+    # shifter extracts the bits-wide value of each element from the int64 words of B
+    scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
+    zeros_ptrs = zeros_ptr + qkv_offset * NK * N // gptq_group_size // infeature_per_bits + (
+        offs_bn[None, :] // infeature_per_bits)
+
+    shifter = (offs_k % infeature_per_bits) * bits
+    zeros_shifter = (offs_bn % infeature_per_bits) * bits
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    # Group indices are sequential here (no act-order), so they are recomputed
+    # per K-block instead of being gathered from memory.
+    g_idx_base = tl.arange(0, BLOCK_SIZE_K)
+    g_idx_base = g_idx_base // gptq_group_size
+    g_idx = g_idx_base
+
+    for k in range(0, num_pid_k):
+        # Fetch scales and zeros for the groups covered by this K-block; these
+        # are per-outfeature and thus reused across the M dimension.
+        scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)    # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+        zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)    # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+        zeros = (zeros >> zeros_shifter[None, :]) & maxq
+        zeros = (zeros + 1)
+        a = tl.load(a_ptrs, mask=a_mask, other=0.)    # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+        b = tl.load(b_ptrs)    # (BLOCK_SIZE_K, BLOCK_SIZE_N), values repeated along K
+        # Unpack b from its int64 words, then dequantize: (q - zero) * scale.
+        b = (b >> shifter[:, None]) & maxq    # Extract the bits-wide values
+        b = (b - zeros).to(tl.float16) * scales    # Scale and shift
+        accumulator += tl.dot(a, b)
+
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += (BLOCK_SIZE_K // infeature_per_bits) * stride_bk
+        # Advance the (sequential) group indices for the next K-block.
+        g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size
+
+    c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+
+    if ADD_BIAS:
+        bias_mask = (offs_bn < N)
+        offs_bn += qkv_offset * N
+        bias_ptrs = bias_ptr + stride_cn * offs_bn
+        bias = tl.load(bias_ptrs, mask=bias_mask, other=0.)    # (BLOCK_SIZE_N,)
+        accumulator += bias[None, :]
+
+    if ACT_TYPE == 1:
+        accumulator = relu(accumulator)
+    elif ACT_TYPE == 2:
+        accumulator = gelu(accumulator)
+    elif ACT_TYPE == 3:
+        accumulator = silu(accumulator)
+
+    if ADD_RESIDUAL:
+        residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+        res = tl.load(residual_ptrs, mask=c_mask, other=0.)
+        accumulator += res
+
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@custom_autotune.autotune(
+    configs=[
+        triton.Config({
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 256,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_SIZE_M': 128,
+            'BLOCK_SIZE_N': 128,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 128,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_SIZE_M': 128,
+            'BLOCK_SIZE_N': 32,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 64,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }, num_stages=4, num_warps=4),
+        triton.Config({
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 128,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8
+        }, num_stages=2, num_warps=8),
+        triton.Config({
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 64,
+            'BLOCK_SIZE_K': 64,
+            'GROUP_SIZE_M': 8
+        }, num_stages=3, num_warps=8),
+        triton.Config({
+            'BLOCK_SIZE_M': 32,
+            'BLOCK_SIZE_N': 32,
+            'BLOCK_SIZE_K': 128,
+            'GROUP_SIZE_M': 8
+        }, num_stages=2, num_warps=4),
+    ],
+    key=['M', 'N', 'K'],
+    nearest_power_of_two=True,
+    prune_configs_by={
+        'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
+        'perf_model': None,
+        'top_k': None,
+    },
+)
+@triton.jit
+def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, idx_ptr, bias_ptr, residual_ptr,
+                                   M, N, K, bits, maxq, gptq_group_size,
+                                   stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
+                                   stride_scales, stride_zeros,
+                                   QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL: tl.constexpr,
+                                   ACT_TYPE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
+                                   BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+    """
+    Compute the matrix multiplication C = A x B, gathering per-row group
+    indices from idx_ptr (for non-sequential / act-order g_idx).
+    A is of shape (M, K) float16
+    B is packed int64 of shape (K // (64 // bits), N)
+    C is of shape (M, N) float16
+    scales is of shape (G, N) float16
+    zeros is packed int64 of shape (G, N // (64 // bits))
+    """
+    infeature_per_bits = 64 // bits
+
+    pid = tl.program_id(axis=0)
+    NK = K
+
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K)
+    qkv_offset = pid // (num_pid_m * num_pid_n)
+    pid = pid % (num_pid_m * num_pid_n)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)    # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+
+    a_mask = (offs_am[:, None] < M)
+    # b_ptrs repeats elements along the K axis (64 // bits) times, since each
+    # packed int64 word of B covers that many rows of K.
+    b_ptrs = b_ptr + qkv_offset * N * NK // infeature_per_bits + (
+        (offs_k[:, None] // infeature_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)    # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+    # shifter extracts the bits-wide value of each element from the int64 words of B
+    scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :]
+    # Note: unlike the kernel above, the zeros pointers carry no qkv offset
+    # here; the gathered g_idx addresses the packed zeros directly.
+    zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infeature_per_bits)
+
+    shifter = (offs_k % infeature_per_bits) * bits
+    zeros_shifter = (offs_bn % infeature_per_bits) * bits
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    g_ptrs = idx_ptr + offs_k
+
+    for k in range(0, num_pid_k):
+        # Gather the (possibly permuted) group index of every K row in this
+        # block, then fetch the matching scales and zeros; these are
+        # per-outfeature and thus reused across the M dimension.
+        g_idx = tl.load(g_ptrs)
+        scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales)    # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+        zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros)    # (BLOCK_SIZE_K, BLOCK_SIZE_N)
+
+        zeros = (zeros >> zeros_shifter[None, :]) & maxq
+        zeros = (zeros + 1)
+
+        a = tl.load(a_ptrs, mask=a_mask, other=0.)    # (BLOCK_SIZE_M, BLOCK_SIZE_K)
+        b = tl.load(b_ptrs)    # (BLOCK_SIZE_K, BLOCK_SIZE_N), values repeated along K
+        # Unpack b from its int64 words, then dequantize: (q - zero) * scale.
+        b = (b >> shifter[:, None]) & maxq
+        b = (b - zeros).to(tl.float16) * scales
+        accumulator += tl.dot(a, b)
+
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += (BLOCK_SIZE_K // infeature_per_bits) * stride_bk
+        g_ptrs += BLOCK_SIZE_K
+
+    c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
+
+    if ADD_BIAS:
+        bias_mask = (offs_bn < N)
+        offs_bn += qkv_offset * N
+        bias_ptrs = bias_ptr + stride_cn * offs_bn
+        bias = tl.load(bias_ptrs, mask=bias_mask, other=0.)    # (BLOCK_SIZE_N,)
+        accumulator += bias[None, :]
+
+    if ACT_TYPE == 1:
+        accumulator = relu(accumulator)
+    elif ACT_TYPE == 2:
+        accumulator = gelu(accumulator)
+    elif ACT_TYPE == 3:
+        accumulator = silu(accumulator)
+
+    if ADD_RESIDUAL:
+        residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
+        res = tl.load(residual_ptrs, mask=c_mask, other=0.)
+        accumulator += res
+
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, bits, maxq, gptq_group_size,
+                             qkv_fused, add_bias, add_residual, g_idx=None, act_type=0):
+    with torch.cuda.device(input.device):
+        # For fused QKV the three output matrices are produced by a grid that
+        # is three times larger.
+        if qkv_fused:
+            grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(
+                qweight.shape[1], META['BLOCK_SIZE_N']) * 3,)
+            output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16)
+        else:
+            grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(
+                qweight.shape[1], META['BLOCK_SIZE_N']),)
+            output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
+        if g_idx is None:
+            cai_gptq_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, bias, residual,
+                                             input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
+                                             gptq_group_size,
+                                             input.stride(0), input.stride(1), qweight.stride(0),
+                                             qweight.stride(1), output.stride(0), output.stride(1),
+                                             scales.stride(0), qzeros.stride(0),
+                                             QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual,
+                                             ACT_TYPE=act_type)
+        else:
+            cai_gptq_idx_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, bias, residual,
+                                                 input.shape[0], qweight.shape[1], input.shape[1], bits, maxq,
+                                                 gptq_group_size,
+                                                 input.stride(0), input.stride(1), qweight.stride(0),
+                                                 qweight.stride(1), output.stride(0), output.stride(1),
+                                                 scales.stride(0), qzeros.stride(0),
+                                                 QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual,
+                                                 ACT_TYPE=act_type)
+    if qkv_fused:
+        return output.view(3, input.shape[0], qweight.shape[1])
+    else:
+        return output
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index e65271621ddd..657bd3eb28d8 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -16,4 +16,5 @@ triton==2.0.0.dev20221202
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
 SentencePiece
 ninja
-flash_attn>=2.0
+flash_attn==2.0.5
+# auto-gptq does not support torch 1.12 yet
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 65eecce2c34f..f6be6a624c70 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,5 +10,4 @@ contexttimer
 ninja
 torch>=1.11
 safetensors
-flash_attn>=2.0
 einops
diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py
new file mode 100644
index 000000000000..4540d990dc3a
--- /dev/null
+++ b/tests/test_gptq/test_linear_act_fusion.py
@@ -0,0 +1,309 @@
+import torch
+import torch.nn as nn
+import pytest
+import time
+import transformers
+from auto_gptq.quantization import GPTQ
+from auto_gptq.modeling._utils import find_layers, pack_model
+from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear
+
+from 
auto_gptq.quantization.quantizer import Quantizer +from colossalai.gptq import CaiGPTQLinearOp +import math +import numpy as np + + +wbits=4 +trits=False +nsamples=1 +percdamp=.01 +groupsize=128 +act_order=False +sym=False +class MLinear(nn.Module): + def __init__(self, infeature, outfeature): + super(MLinear, self).__init__() + self.linear = torch.nn.Linear(infeature, outfeature, dtype=torch.float16) + def forward(self, x): + out = self.linear(x) + return out + +@torch.no_grad() +def model_quant(model, inps, dev): + print('Starting ...') + layers = [model] + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + cache = {'i': 0} + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + raise ValueError + layers[0] = Catcher(layers[0]) + # for batch in inps: + try: + model(inps.to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) + + print('Ready.') + + quantizers = {} + for i in range(len(layers)): + layer = layers[i].to(dev) + subset = find_layers(layer) + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + # gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False, trits=trits) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0))[0] + + for h in handles: + h.remove() + for name in subset: + print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') + scale,zero,g_idx = gptq[name].fasterquant(percdamp=percdamp, group_size=groupsize, actorder=act_order) + # quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + quantizers['%s' % (name)] = (gptq[name].layer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + + gptq[name].free() + for j in range(nsamples): + layer = layer.to(dev) + outs[j] = layer(inps[j].unsqueeze(0))[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + + return quantizers + + +def model_pack(model, quantizers, wbits, groupsize): + pack_model(model, quantizers, wbits, groupsize) + return model + + +def cai_linear_pack(linear, scales, zeros, + out_qweight, out_qscales, out_qzeros, qg_idx, + infeatures, groupsize, bits): + g_idx = qg_idx.clone() if qg_idx is not None else torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + + out_qscales.data.copy_(scales) + + wn = 16 + pbits = 64 + ptype = torch.int64 + unsign_type = np.uint64 + sign_type = np.int64 + + # wn = 8 + # pbits = 32 + # ptype = torch.int32 + # unsign_type = np.uint32 + # sign_type = np.int32 + + intweight = [] + for idx in range(infeatures): + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits 
* bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + # print("weight shape ", intweight.shape, qweight.shape, out_qweight.shape, bits) + # print("weight shape ", intweight[0].shape, qweight[0].shape, out_qweight[0].shape) + # print("weight value ", intweight[0], qweight[0]) + + while row < qweight.shape[0]: + if bits in [2, 4, 8]: + for j in range(i, i + (pbits // bits)): + qweight[row] |= intweight[j] << (bits * (j - i)) + i += pbits // bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous().to("cuda") + out_qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if bits in [2, 4, 8]: + for j in range(i, i + (pbits // bits)): + qzeros[:, col] |= zeros[:, j] << (bits * (j - i)) + i += pbits // bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros.to("cuda") + out_qzeros.data.copy_(qzeros) + + return out_qweight, out_qscales, out_qzeros + +def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + with torch.no_grad(): + for name in layers: + _, scale, zero, g_idx = quantizers[name] + qweight, qscales, qzeros = cai_linear_pack(layers[name], scale, zero, + qweight, qscales, qzeros, g_idx, + layers[name].weight.shape[-1], groupsize, wbits) + + # print("cai pack", layers) + return qweight, qscales, qzeros + + +def test_gptq_linear(): + + infeature = 5120 + outfeature = 5120 + + weight = torch.randn(outfeature, infeature).to(torch.float16).to(torch.cuda.current_device()) + bias = torch.zeros(outfeature).to(torch.float16).to(torch.cuda.current_device()) + wn = 16 + ptype = torch.int64 + + # wn = 8 + # ptype = torch.int32 + + qweight = torch.zeros(infeature//wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() + qscales = torch.zeros(infeature//groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() + qzeros = torch.zeros(infeature//groupsize, outfeature//wn, dtype=ptype, device=torch.cuda.current_device()).contiguous() + + act_func = nn.SiLU() + inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) + + linear = MLinear(infeature, outfeature) + linear.to(torch.cuda.current_device()) + + with torch.no_grad(): + linear.linear.weight.data.copy_(weight) + linear.linear.bias.data.copy_(bias) + + with torch.no_grad(): + torch_out = linear(inps) + batch_torch_out = linear(batch_inps) + torch_out = act_func(torch_out) + batch_torch_out = act_func(batch_torch_out) + + + # linear.to("cuda") + quantizers = model_quant(linear, inps, torch.cuda.current_device()) + qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, wbits, groupsize) + gptq_model = model_pack(linear, quantizers, wbits, groupsize) + gptq_model.to(torch.cuda.current_device()) + # gptq_model = linear + + + cai_linear = CaiGPTQLinearOp(groupsize, wbits) + + + # qweight = torch.cat((qweight, qweight, qweight), dim=0).contiguous() + # qscales = torch.cat((qscales, qscales, qscales), 
dim=0).contiguous()
+    # bias = torch.cat((bias, bias, bias), dim=0).contiguous()
+    qkv_fused = False
+
+    with torch.no_grad():
+        gptq_out = gptq_model(inps)
+        batch_gptq_out = gptq_model(batch_inps)
+        cai_out = cai_linear(inps,
+                             qweight,
+                             qscales,
+                             qzeros,
+                             bias=bias,
+                             act_type=3,
+                             qkv_fused=qkv_fused)
+        torch.cuda.synchronize()
+
+        batch_cai_out = cai_linear(batch_inps,
+                                   qweight,
+                                   qscales,
+                                   qzeros,
+                                   bias=bias,
+                                   act_type=3,
+                                   qkv_fused=qkv_fused)
+        torch.cuda.synchronize()
+        batch_gptq_out = act_func(batch_gptq_out)
+        gptq_out = act_func(gptq_out)
+
+    # The fused op applies SiLU (act_type=3) in its epilogue, so compare
+    # against the reference outputs after act_func has been applied.
+    assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-02)
+    assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-02)
+
+
+if __name__ == "__main__":
+    test_gptq_linear()
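
Usage sketch (illustrative, not part of the patch): the pieces above fit together as follows. `make_cai_quant_linear` swaps selected `nn.Linear` modules for `CaiQuantLinear` shells, `pack()` fills the int64 buffers from per-group scales and zero points, and `forward()` dispatches to `gptq_fused_linear_triton`. The sketch assumes auto-gptq and a CUDA device are available; the toy model and the hand-picked scales/zeros are placeholders for what a real GPTQ quantizer would produce.

```python
import torch
import torch.nn as nn

from colossalai.gptq import CaiQuantLinear, make_cai_quant_linear

bits, groupsize, in_f, out_f = 4, 128, 512, 512

model = nn.Sequential(nn.Linear(in_f, out_f, dtype=torch.float16))
fp16_linear = model[0]    # keep a handle; it is detached from the model below

# Replace every module whose dotted name appears in `names` with a CaiQuantLinear shell.
make_cai_quant_linear(model, names=['0'], bits=bits, groupsize=groupsize)
qlinear = model[0]
assert isinstance(qlinear, CaiQuantLinear)

# In real use, scales/zeros come from an AutoGPTQ quantizer (see the test above);
# here we fabricate per-group parameters just to exercise pack().
n_groups = in_f // groupsize
scales = torch.full((out_f, n_groups), 0.01)    # one scale per (outfeature, group)
zeros = torch.full((out_f, n_groups), 8.0)      # integer zero points, stored as float
qlinear.pack(fp16_linear, scales, zeros)

model = model.cuda()
x = torch.randn(1, 8, in_f, dtype=torch.float16, device='cuda')
y = model(x)    # runs gptq_fused_linear_triton under the hood
print(y.shape)  # torch.Size([1, 8, 512])
```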
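For reference, the packing layout shared by `pack()`, `cai_linear_pack` and the kernels can be checked in isolation: with `pbits = 64` and 4-bit values, each int64 word holds 64 // 4 = 16 consecutive values, value `j` occupying bits `[4*j, 4*j + 4)`, and the kernels recover it with exactly the `shifter`/`maxq` arithmetic above. A standalone NumPy sketch:

```python
import numpy as np

bits, pbits = 4, 64
vals_per_word = pbits // bits                       # 16 values per int64 word
vals = np.arange(vals_per_word, dtype=np.uint64)    # 0..15, one value per slot

# Pack: OR each value into its bit slot, as in the `qweight[row] |= ...` loop.
word = np.uint64(0)
for j, v in enumerate(vals):
    word |= v << np.uint64(bits * j)

# Unpack: shift and mask, mirroring `(b >> shifter) & maxq` in the kernels.
maxq = np.uint64(2**bits - 1)
unpacked = [(word >> np.uint64(bits * j)) & maxq for j in range(vals_per_word)]

assert np.array_equal(np.array(unpacked, dtype=np.uint64), vals)
print(hex(int(word)))    # 0xfedcba9876543210
```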