From aaa8c4f4595b551b1c3a9c1e1cb7370681050126 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 23 Aug 2023 09:33:06 +0800 Subject: [PATCH 1/5] [skip ci] add cuda kernels --- colossalai/gptq/cai_gptq/cai_quant_linear.py | 120 +++++-- colossalai/gptq/cai_gptq/gptq_triton.py | 13 +- .../cuda_native/csrc/gptq/column_remap.cu | 63 ++++ .../cuda_native/csrc/gptq/column_remap.cuh | 19 ++ .../cuda_native/csrc/gptq/cu_compat.cuh | 58 ++++ .../cuda_native/csrc/gptq/cuda_buffers.cu | 75 +++++ .../cuda_native/csrc/gptq/cuda_buffers.cuh | 55 ++++ .../cuda_native/csrc/gptq/hip_compat.cuh | 49 +++ .../cuda_native/csrc/gptq/linear_gptq.cpp | 254 +++++++++++++++ .../kernel/cuda_native/csrc/gptq/matrix.cuh | 294 ++++++++++++++++++ .../kernel/cuda_native/csrc/gptq/q4_matmul.cu | 260 ++++++++++++++++ .../cuda_native/csrc/gptq/q4_matmul.cuh | 43 +++ .../kernel/cuda_native/csrc/gptq/q4_matrix.cu | 225 ++++++++++++++ .../cuda_native/csrc/gptq/q4_matrix.cuh | 53 ++++ .../kernel/cuda_native/csrc/gptq/tuning.h | 13 + .../kernel/cuda_native/csrc/gptq/util.cuh | 33 ++ op_builder/gptq.py | 52 ++++ ...near_act_fusion.py => test_gptq_linear.py} | 105 +++---- 18 files changed, 1697 insertions(+), 87 deletions(-) create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/tuning.h create mode 100644 colossalai/kernel/cuda_native/csrc/gptq/util.cuh create mode 100644 op_builder/gptq.py rename tests/test_gptq/{test_linear_act_fusion.py => test_gptq_linear.py} (80%) diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index 737b24462dc4..16285dc17e29 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -1,3 +1,4 @@ +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ import math import numpy as np @@ -5,9 +6,27 @@ import torch.nn as nn from .gptq_op import CaiGPTQLinearOp import triton +import warnings + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + class CaiQuantLinear(nn.Module): - + max_dq_buffer_size=1 + max_inner_outer_dim=1 + max_input_len=1 + prepared_buffers=False + device_to_buffers = { + "temp_state": None, + "temp_dq": None, + } def __init__(self, bits, groupsize, infeatures, outfeatures, bias): super().__init__() if bits not in [2, 4, 8]: @@ -18,8 +37,8 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): self.maxq = 2**self.bits - 1 self.groupsize = groupsize if groupsize != -1 else infeatures - self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64)) - self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64)) + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) @@ -30,6 +49,8 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) + self.q4 = None + self.empty_tensor = torch.empty((1, 1), device="meta") def pack(self, linear, scales, zeros, g_idx=None): @@ -44,17 +65,17 @@ def pack(self, linear, scales, zeros, g_idx=None): if linear.bias is not None: self.bias = linear.bias.clone().half() - wn = 16 - pbits = 64 - ptype = torch.int64 - unsign_type = np.uint64 - sign_type = np.int64 + # wn = 16 + # pbits = 64 + # ptype = torch.int64 + # unsign_type = np.uint64 + # sign_type = np.int64 - # wn = 8 - # pbits = 32 - # ptype = torch.int32 - # unsign_type = np.uint32 - # sign_type = np.int32 + wn = 8 + pbits = 32 + ptype = torch.int32 + unsign_type = np.uint32 + sign_type = np.int32 intweight = [] for idx in range(self.infeatures): @@ -101,21 +122,75 @@ def pack(self, linear, scales, zeros, g_idx=None): qzeros = qzeros self.qzeros.data.copy_(qzeros) - if torch.equal(self.g_idx, g_idx): + if torch.equal(self.g_idx.to(g_idx.device), g_idx): self.g_idx = None else: self.g_idx = g_idx + CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8) - def forward(self, x): + if self.g_idx is not None: + CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures, self.outfeatures) + max_input_len=4096 + + + def prepare_buffers(self): + assert self.qweight.device.type == "cuda" + device = self.qweight.device - cai_out = self.gptq_linear(x, - self.qweight, - self.scales, - self.qzeros, - g_idx = self.g_idx, - bias = self.bias,) - return cai_out + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros((CaiQuantLinear.max_input_len, CaiQuantLinear.max_inner_outer_dim), dtype=torch.float16, device=device) + CaiQuantLinear.device_to_buffers['temp_dp'] = torch.zeros((1, CaiQuantLinear.max_dq_buffer_size), dtype=torch.float16, device=device) + + gptq_cuda.prepare_buffers(torch.device(device), CaiQuantLinear.device_to_buffers['temp_state'], CaiQuantLinear.device_to_buffers['temp_dp']) + + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + torch.cuda.empty_cache() + def init_q4(self): + assert self.qweight.device.type == "cuda" + self.q4_width = self.qweight.shape[1] + if self.g_idx is not None: + g_idx = self.g_idx.to("cpu") + else: + g_idx = self.empty_tensor + + self.q4 = gptq_cuda.make_q4(self.qweight, + self.qzeros, + self.scales, + g_idx, + torch.cuda.current_device()) + torch.cuda.synchronize() + + def forward(self, x): + outshape = x.shape[:-1] + (self.outfeatures,) + + if HAS_GPTQ_CUDA: + if CaiQuantLinear.prepared_buffers == False: + self.prepare_buffers() + CaiQuantLinear.prepared_buffers = True + + if self.q4 is None: + self.init_q4() + + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) + gptq_cuda.q4_matmul(x, self.q4, output) + if self.bias is not None: + output.add_(self.bias) + else: + output = self.gptq_linear(x, + self.qweight, + self.scales, + self.qzeros, + g_idx = self.g_idx, + bias = self.bias,) + return output.view(outshape) def make_cai_quant_linear(module, names, bits, groupsize, name=''): if isinstance(module, CaiQuantLinear): @@ -128,4 +203,3 @@ def make_cai_quant_linear(module, names, bits, groupsize, name=''): setattr(module, attr, CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) for name1, child in module.named_children(): make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) - diff --git a/colossalai/gptq/cai_gptq/gptq_triton.py b/colossalai/gptq/cai_gptq/gptq_triton.py index 8a505ebad73f..def711d1e6c4 100644 --- a/colossalai/gptq/cai_gptq/gptq_triton.py +++ b/colossalai/gptq/cai_gptq/gptq_triton.py @@ -168,12 +168,12 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ """ Compute the matrix multiplication C = A x B. A is of shape (M, K) float16 - B is of shape (K//16, N) int64 + B is of shape (K//8, N) int32 C is of shape (M, N) float16 scales is of shape (G, N) float16 zeros is of shape (G, N) float16 """ - infearure_per_bits = 64 // bits + infearure_per_bits = 32 // bits pid = tl.program_id(axis=0) NK = K @@ -334,12 +334,12 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i """ Compute the matrix multiplication C = A x B. A is of shape (M, K) float16 - B is of shape (K//16, N) int64 + B is of shape (K//8, N) int32 C is of shape (M, N) float16 scales is of shape (G, N) float16 zeros is of shape (G, N) float16 """ - infearure_per_bits = 64 // bits + infearure_per_bits = 32 // bits pid = tl.program_id(axis=0) NK = K @@ -439,6 +439,11 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, bits, maxq, gptq_group_size, qkv_fused, add_bias, add_residual, g_idx = None, act_type = 0): # print("gptq fused ", qkv_fused, add_bias, add_residual) + assert input.is_cuda, "input is not in cuda" + assert qweight.is_cuda, "qweight is not in cuda" + assert scales.is_cuda, "scales is not in cuda" + assert qzeros.is_cuda, "qzeros is not in cuda" + with torch.cuda.device(input.device): if qkv_fused: grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']) * 3, ) diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu new file mode 100644 index 000000000000..2b1b366b1c02 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu @@ -0,0 +1,63 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + if (x_column >= x_width) return; + //if (x_row >= x_height) return; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh new file mode 100644 index 000000000000..6571c17d6fd5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh new file mode 100644 index 000000000000..c5258813e147 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu new file mode 100644 index 000000000000..4416027c8387 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu @@ -0,0 +1,75 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state_size(_temp_state_size), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state_size, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh new file mode 100644 index 000000000000..0bf2057c665c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh @@ -0,0 +1,55 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + int temp_state_size; + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + int _temp_state_size, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh new file mode 100644 index 000000000000..5cd2e8553ef6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh @@ -0,0 +1,49 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _hip_compat_cuh +#define _hip_compat_cuh + +// Workaround for a bug in hipamd, backported from upstream. +__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), + static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; +} + +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Workaround for hipify_python using rocblas instead of hipblas. +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + +#define rocblas_handle hipblasHandle_t +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream +#define rocblas_hgemm __compat_hipblasHgemm + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp new file mode 100644 index 000000000000..bcc0e43901de --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp @@ -0,0 +1,254 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "q4_matrix.cuh" +#include "q4_matmul.cuh" +#include "column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + // buffer size used for sanity checks + temp_state.numel(), + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh new file mode 100644 index 000000000000..2fd5ab0b36cd --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu new file mode 100644 index 000000000000..f47daeb0e877 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu @@ -0,0 +1,260 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include "util.cuh" +#include "matrix.cuh" +#include "cu_compat.cuh" +#include "cuda_buffers.cuh" +#if defined(USE_ROCM) +#include "hip_compat.cuh" +#endif + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(__low2half(acc), __high2half(acc)); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 + const float alpha = 1.0f; + const float beta = no_zero ? 1.0f : 0.0f; + cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, + x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +#else + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); +#endif +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh new file mode 100644 index 000000000000..09f3e1a63362 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh @@ -0,0 +1,43 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "tuning.h" + +// Workaround for hipify_python using rocblas instead of hipblas. +#if defined(USE_ROCM) +#include +#define rocblas_handle hipblasHandle_t +#endif + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero = false, + cudaStream_t alt_stream = NULL +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero = false +); + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu new file mode 100644 index 000000000000..9c61143f565e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu @@ -0,0 +1,225 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matrix.cuh" +#include +#include "util.cuh" +#include "matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks + ( + (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), + height / 8, + 1 + ); + + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + if (column >= width) return; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh new file mode 100644 index 000000000000..50cb72a41518 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h new file mode 100644 index 000000000000..770ca46aa7c8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h @@ -0,0 +1,13 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning +{ + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh new file mode 100644 index 000000000000..7b397573214b --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh @@ -0,0 +1,33 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#if defined(USE_ROCM) +#define cudaUnspecified hipErrorUnknown +#else +#define cudaUnspecified cudaErrorApiFailureBase +#endif + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/op_builder/gptq.py b/op_builder/gptq.py new file mode 100644 index 000000000000..012cf0f8a78d --- /dev/null +++ b/op_builder/gptq.py @@ -0,0 +1,52 @@ +import os +import torch +import re + +from .builder import Builder +from .utils import append_nvcc_threads, get_cuda_cc_flag + +class GPTQBuilder(Builder): + + NAME = "cu_gptq" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" + + def __init__(self): + super().__init__(name=GPTQBuilder.NAME, + prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) + + + def include_dirs(self): + ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) for fname in [ + 'gptq/linear_gptq.cpp', + 'gptq/column_remap.cu', + 'gptq/cuda_buffers.cu', + 'gptq/q4_matmul.cu', + 'gptq/q4_matrix.cu' + ] + ] + return ret + + def cxx_flags(self): + return ['-O3'] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ['-v', + '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK', "-lcublas", "-std=c++17" + ] + + + for arch in torch.cuda.get_arch_list(): + res = re.search(r'sm_(\d+)', arch) + if res: + arch_cap = res[1] + if int(arch_cap) >= 80: + extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + + ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) \ No newline at end of file diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_gptq_linear.py similarity index 80% rename from tests/test_gptq/test_linear_act_fusion.py rename to tests/test_gptq/test_gptq_linear.py index 4540d990dc3a..eb42e3613768 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -4,11 +4,11 @@ import time import transformers from auto_gptq.quantization import GPTQ -from auto_gptq.modeling._utils import find_layers, pack_model +from auto_gptq.modeling._utils import find_layers, pack_model,autogptq_post_init from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear from auto_gptq.quantization.quantizer import Quantizer -from colossalai.gptq import CaiGPTQLinearOp +from colossalai.gptq import CaiGPTQLinearOp, CaiQuantLinear import math import numpy as np @@ -119,17 +119,17 @@ def cai_linear_pack(linear, scales, zeros, out_qscales.data.copy_(scales) - wn = 16 - pbits = 64 - ptype = torch.int64 - unsign_type = np.uint64 - sign_type = np.int64 + # wn = 16 + # pbits = 64 + # ptype = torch.int64 + # unsign_type = np.uint64 + # sign_type = np.int64 - # wn = 8 - # pbits = 32 - # ptype = torch.int32 - # unsign_type = np.uint32 - # sign_type = np.int32 + wn = 8 + pbits = 32 + ptype = torch.int32 + unsign_type = np.uint32 + sign_type = np.int32 intweight = [] for idx in range(infeatures): @@ -178,6 +178,16 @@ def cai_linear_pack(linear, scales, zeros, return out_qweight, out_qscales, out_qzeros + +def get_model_param(model, quantizers): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + with torch.no_grad(): + for name in layers: + _, scale, zero, g_idx = quantizers[name] + + return scale, zero, g_idx + def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize): layers = find_layers(model) layers = {n: layers[n] for n in quantizers} @@ -199,11 +209,11 @@ def test_gptq_linear(): weight = torch.randn(outfeature, infeature).to(torch.float16).to(torch.cuda.current_device()) bias = torch.zeros(outfeature).to(torch.float16).to(torch.cuda.current_device()) - wn = 16 - ptype = torch.int64 + # wn = 16 + # ptype = torch.int64 - # wn = 8 - # ptype = torch.int32 + wn = 8 + ptype = torch.int32 qweight = torch.zeros(infeature//wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() qscales = torch.zeros(infeature//groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() @@ -211,7 +221,7 @@ def test_gptq_linear(): act_func = nn.SiLU() inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) - batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 2048, infeature).to(torch.float16).to(torch.cuda.current_device()) linear = MLinear(infeature, outfeature) linear.to(torch.cuda.current_device()) @@ -223,66 +233,41 @@ def test_gptq_linear(): with torch.no_grad(): torch_out = linear(inps) batch_torch_out = linear(batch_inps) - torch_out = act_func(torch_out) - batch_torch_out = act_func(batch_torch_out) + # torch_out = act_func(torch_out) + # batch_torch_out = act_func(batch_torch_out) # linear.to("cuda") quantizers = model_quant(linear, inps, torch.cuda.current_device()) - qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, wbits, groupsize) - gptq_model = model_pack(linear, quantizers, wbits, groupsize) - gptq_model.to(torch.cuda.current_device()) - # gptq_model = linear + # qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, wbits, groupsize) + scale, zero, g_idx = get_model_param(linear, quantizers) + cai_linear = CaiQuantLinear(wbits, groupsize, infeature, outfeature, True) - cai_linear = CaiGPTQLinearOp(groupsize, wbits) + cai_linear.to("cuda") + cai_linear.pack(linear.linear, scale, zero, g_idx) + cai_linear.to("cuda") + gptq_model = model_pack(linear, quantizers, wbits, groupsize) + gptq_model.to(torch.cuda.current_device()) + gptq_model=autogptq_post_init(gptq_model, False) - # qweight = torch.cat((qweight, qweight, qweight), dim=0).contiguous() - # qscales = torch.cat((qscales, qscales, qscales), dim=0).contiguous() - # qzeros = torch.cat((qzeros, qzeros, qzeros), dim=0).contiguous() - # bias = torch.cat((bias, bias, bias), dim=0).contiguous() qkv_fused=False with torch.no_grad(): gptq_out = gptq_model(inps) batch_gptq_out = gptq_model(batch_inps) - cai_out = cai_linear(inps, - qweight, - qscales, - qzeros, - bias = bias, - act_type = 3, - qkv_fused=qkv_fused) torch.cuda.synchronize() - - batch_cai_out = cai_linear(batch_inps, - qweight, - qscales, - qzeros, - bias=bias, - act_type = 3, - qkv_fused=qkv_fused) + cai_out = cai_linear(inps) torch.cuda.synchronize() - batch_gptq_out = act_func(batch_gptq_out) - gptq_out = act_func(gptq_out) - - # cai_out = cai_out[1] - # batch_cai_out = batch_cai_out[1] - # a = torch.sum(qscales, 0) - # print("qscales ", a) - # print("orch out ", torch_out) - # print("gptq out ", gptq_out) - # print("cai out ", cai_out) - # # print("batch_torch out ", batch_torch_out) - # print("batch_torch out ", batch_torch_out) - # print("batch_gptq out ", batch_gptq_out) - # print("batch_cai out ", batch_cai_out) - - assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-02) - assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-02) + batch_cai_out = cai_linear(batch_inps) + torch.cuda.synchronize() + # batch_gptq_out = act_func(batch_gptq_out) + # gptq_out = act_func(gptq_out) + assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01) + assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01) # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) # max_diff = torch.max(torch.abs(cai_out - gptq_out)) From ae92d84818a7c885d615727fde4610793dc073be Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 23 Aug 2023 09:40:14 +0800 Subject: [PATCH 2/5] add license --- LICENSE | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/LICENSE b/LICENSE index c7a5bb16880e..0db47bd8986f 100644 --- a/LICENSE +++ b/LICENSE @@ -396,3 +396,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR AutoGPTQ ---------------- + + From AutoGPTQ: + + MIT License + + Copyright (c) 2023 潘其威(William) + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + ---------------- LICENSE FOR exllama ---------------- + + From exllama: + + MIT License + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. From 5e565ba76635f54f6fff7db926e7421f34e76784 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 23 Aug 2023 11:57:53 +0800 Subject: [PATCH 3/5] [skip ci] fix max_input_len --- colossalai/gptq/cai_gptq/cai_quant_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index 16285dc17e29..b42aa09c1dcd 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -131,7 +131,7 @@ def pack(self, linear, scales, zeros, g_idx=None): if self.g_idx is not None: CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures, self.outfeatures) - max_input_len=4096 + CaiQuantLinear.max_input_len=4096 def prepare_buffers(self): From 98b1a1d99a360863d45b4926a1e79698eb19c515 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 23 Aug 2023 15:52:59 +0800 Subject: [PATCH 4/5] format files & change test size --- colossalai/gptq/__init__.py | 7 +- colossalai/gptq/cai_gptq/__init__.py | 4 +- colossalai/gptq/cai_gptq/cai_quant_linear.py | 95 +++--- colossalai/gptq/cai_gptq/gptq_op.py | 28 +- colossalai/gptq/cai_gptq/gptq_triton.py | 294 ++++++++++++------- tests/test_gptq/test_gptq_linear.py | 130 ++++---- 6 files changed, 341 insertions(+), 217 deletions(-) diff --git a/colossalai/gptq/__init__.py b/colossalai/gptq/__init__.py index 0e0ee5152138..59b87d6ca692 100644 --- a/colossalai/gptq/__init__.py +++ b/colossalai/gptq/__init__.py @@ -1,7 +1,4 @@ from .cai_gptq import HAS_AUTO_GPTQ -if HAS_AUTO_GPTQ: - from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear, - CaiQuantLinear, CaiGPTQLinearOp) - - +if HAS_AUTO_GPTQ: + from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear, gptq_fused_linear_triton, make_cai_quant_linear diff --git a/colossalai/gptq/cai_gptq/__init__.py b/colossalai/gptq/cai_gptq/__init__.py index 68addb8fb2f5..fcdef7734438 100644 --- a/colossalai/gptq/cai_gptq/__init__.py +++ b/colossalai/gptq/cai_gptq/__init__.py @@ -9,6 +9,6 @@ HAS_AUTO_GPTQ = False if HAS_AUTO_GPTQ: - from .gptq_triton import gptq_fused_linear_triton - from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear + from .cai_quant_linear import CaiQuantLinear, make_cai_quant_linear from .gptq_op import CaiGPTQLinearOp + from .gptq_triton import gptq_fused_linear_triton diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index b42aa09c1dcd..c65b325d54ee 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -1,12 +1,14 @@ # Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ import math +import warnings + import numpy as np import torch import torch.nn as nn -from .gptq_op import CaiGPTQLinearOp import triton -import warnings + +from .gptq_op import CaiGPTQLinearOp HAS_GPTQ_CUDA = False try: @@ -16,17 +18,18 @@ except ImportError: warnings.warn('CUDA gptq is not installed') HAS_GPTQ_CUDA = False - + class CaiQuantLinear(nn.Module): - max_dq_buffer_size=1 - max_inner_outer_dim=1 - max_input_len=1 - prepared_buffers=False + max_dq_buffer_size = 1 + max_inner_outer_dim = 1 + max_input_len = 1 + prepared_buffers = False device_to_buffers = { "temp_state": None, "temp_dq": None, } + def __init__(self, bits, groupsize, infeatures, outfeatures, bias): super().__init__() if bits not in [2, 4, 8]: @@ -38,8 +41,11 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): self.groupsize = groupsize if groupsize != -1 else infeatures self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) - self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) - self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer( + 'qzeros', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) if bias: @@ -54,7 +60,8 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): def pack(self, linear, scales, zeros, g_idx=None): - g_idx = g_idx.clone() if g_idx is not None else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + g_idx = g_idx.clone() if g_idx is not None else torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) scales = scales.t().contiguous() zeros = zeros.t().contiguous() @@ -79,7 +86,10 @@ def pack(self, linear, scales, zeros, g_idx=None): intweight = [] for idx in range(self.infeatures): - intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, + None]) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(unsign_type) @@ -93,27 +103,27 @@ def pack(self, linear, scales, zeros, g_idx=None): while row < qweight.shape[0]: if self.bits in [2, 4, 8]: - for j in range(i, i + (pbits // self.bits)): - qweight[row] |= intweight[j] << ( self.bits * (j - i)) - i += pbits // self.bits + for j in range(i, i + (pbits // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += pbits // self.bits row += 1 else: raise NotImplementedError("Only 2,4,8 bits are supported.") qweight = qweight.astype(sign_type) qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous() #.to("cuda") + qweight1 = qweight1.contiguous() #.to("cuda") self.qweight.data.copy_(qweight1) - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) zeros -= 1 zeros = zeros.numpy().astype(unsign_type) i = 0 col = 0 while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (pbits // self.bits)): - qzeros[:, col] |= zeros[:, j] << ( self.bits * (j - i)) - i += pbits // self.bits + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += pbits // self.bits col += 1 else: raise NotImplementedError("Only 2,4,8 bits are supported.") @@ -121,7 +131,7 @@ def pack(self, linear, scales, zeros, g_idx=None): qzeros = torch.from_numpy(qzeros) qzeros = qzeros self.qzeros.data.copy_(qzeros) - + if torch.equal(self.g_idx.to(g_idx.device), g_idx): self.g_idx = None else: @@ -130,9 +140,9 @@ def pack(self, linear, scales, zeros, g_idx=None): CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8) if self.g_idx is not None: - CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures, self.outfeatures) - CaiQuantLinear.max_input_len=4096 - + CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures, + self.outfeatures) + CaiQuantLinear.max_input_len = 4096 def prepare_buffers(self): assert self.qweight.device.type == "cuda" @@ -140,10 +150,14 @@ def prepare_buffers(self): # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros((CaiQuantLinear.max_input_len, CaiQuantLinear.max_inner_outer_dim), dtype=torch.float16, device=device) - CaiQuantLinear.device_to_buffers['temp_dp'] = torch.zeros((1, CaiQuantLinear.max_dq_buffer_size), dtype=torch.float16, device=device) + CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros( + (CaiQuantLinear.max_input_len, CaiQuantLinear.max_inner_outer_dim), dtype=torch.float16, device=device) + CaiQuantLinear.device_to_buffers['temp_dp'] = torch.zeros((1, CaiQuantLinear.max_dq_buffer_size), + dtype=torch.float16, + device=device) - gptq_cuda.prepare_buffers(torch.device(device), CaiQuantLinear.device_to_buffers['temp_state'], CaiQuantLinear.device_to_buffers['temp_dp']) + gptq_cuda.prepare_buffers(torch.device(device), CaiQuantLinear.device_to_buffers['temp_state'], + CaiQuantLinear.device_to_buffers['temp_dp']) # Using the default from exllama repo here. matmul_recons_thd = 8 @@ -152,6 +166,7 @@ def prepare_buffers(self): gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) torch.cuda.empty_cache() + def init_q4(self): assert self.qweight.device.type == "cuda" self.q4_width = self.qweight.shape[1] @@ -160,11 +175,7 @@ def init_q4(self): else: g_idx = self.empty_tensor - self.q4 = gptq_cuda.make_q4(self.qweight, - self.qzeros, - self.scales, - g_idx, - torch.cuda.current_device()) + self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device()) torch.cuda.synchronize() def forward(self, x): @@ -174,7 +185,7 @@ def forward(self, x): if CaiQuantLinear.prepared_buffers == False: self.prepare_buffers() CaiQuantLinear.prepared_buffers = True - + if self.q4 is None: self.init_q4() @@ -184,14 +195,17 @@ def forward(self, x): if self.bias is not None: output.add_(self.bias) else: - output = self.gptq_linear(x, - self.qweight, - self.scales, - self.qzeros, - g_idx = self.g_idx, - bias = self.bias,) + output = self.gptq_linear( + x, + self.qweight, + self.scales, + self.qzeros, + g_idx=self.g_idx, + bias=self.bias, + ) return output.view(outshape) + def make_cai_quant_linear(module, names, bits, groupsize, name=''): if isinstance(module, CaiQuantLinear): return @@ -200,6 +214,7 @@ def make_cai_quant_linear(module, names, bits, groupsize, name=''): name1 = name + '.' + attr if name != '' else attr if name1 in names: delattr(module, attr) - setattr(module, attr, CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + setattr(module, attr, + CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) for name1, child in module.named_children(): make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) diff --git a/colossalai/gptq/cai_gptq/gptq_op.py b/colossalai/gptq/cai_gptq/gptq_op.py index aca1cb5b87c5..32cbab743228 100644 --- a/colossalai/gptq/cai_gptq/gptq_op.py +++ b/colossalai/gptq/cai_gptq/gptq_op.py @@ -1,6 +1,7 @@ -from .gptq_triton import gptq_fused_linear_triton import torch +from .gptq_triton import gptq_fused_linear_triton + class CaiGPTQLinearOp(torch.nn.Module): @@ -17,10 +18,10 @@ def forward(self, weight_scales: torch.Tensor, weight_zeros: torch.Tensor, g_idx: torch.Tensor = None, - act_type = 0, + act_type=0, bias: torch.Tensor = None, - residual: torch.Tensor=None, - qkv_fused = False): + residual: torch.Tensor = None, + qkv_fused=False): add_bias = True if bias is None: @@ -33,12 +34,23 @@ def forward(self, add_residual = False x = input.view(-1, input.shape[-1]) - out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual, - self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual, - act_type=act_type, g_idx=g_idx) + out = gptq_fused_linear_triton(x, + weight, + weight_scales, + weight_zeros, + bias, + residual, + self.bits, + self.maxq, + self.group_size, + qkv_fused, + add_bias, + add_residual, + act_type=act_type, + g_idx=g_idx) if qkv_fused: out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) else: out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) - return out \ No newline at end of file + return out diff --git a/colossalai/gptq/cai_gptq/gptq_triton.py b/colossalai/gptq/cai_gptq/gptq_triton.py index def711d1e6c4..231483258f18 100644 --- a/colossalai/gptq/cai_gptq/gptq_triton.py +++ b/colossalai/gptq/cai_gptq/gptq_triton.py @@ -1,15 +1,17 @@ +import torch import triton import triton.language as tl -import torch from auto_gptq.nn_modules.triton_utils import custom_autotune + # from ..ops.triton.kernels.activations_kernels import relu, gelu, silu # code based https://github.com/fpgaminer/GPTQ-triton - # triton.Config({ - # 'BLOCK_SIZE_M': 32, - # 'BLOCK_SIZE_N': 32, - # 'BLOCK_SIZE_K': 128, - # 'GROUP_SIZE_M': 8 - # }, num_stages=2, num_warps=4), +# triton.Config({ +# 'BLOCK_SIZE_M': 32, +# 'BLOCK_SIZE_N': 32, +# 'BLOCK_SIZE_K': 128, +# 'GROUP_SIZE_M': 8 +# }, num_stages=2, num_warps=4), + @triton.jit def tanh(x): @@ -91,13 +93,12 @@ def smelu(x): beta = 2.0 relu = tl.where(x >= beta, x, 0.0) - return tl.where( - tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) + return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) @triton.jit def silu(x): - return x*tl.sigmoid(x) + return x * tl.sigmoid(x) @custom_autotune.autotune( @@ -107,49 +108,65 @@ def silu(x): 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), + }, + num_stages=2, + num_warps=8), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), + }, + num_stages=3, + num_warps=8), triton.Config({ 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), + }, + num_stages=2, + num_warps=4), ], key=['M', 'N', 'K'], nearest_power_of_two=True, @@ -160,11 +177,11 @@ def silu(x): }, ) @triton.jit -def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, - M, N, K, bits, maxq, gptq_group_size, - stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, - QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): +def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, M, N, K, bits, maxq, + gptq_group_size, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + stride_scales, stride_zeros, QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, + ADD_RESIDUAL: tl.constexpr, ACT_TYPE: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """ Compute the matrix multiplication C = A x B. A is of shape (M, K) float16 @@ -181,7 +198,7 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) - qkv_offset = pid // (num_pid_m * num_pid_n) + qkv_offset = pid // (num_pid_m * num_pid_n) pid = pid % (num_pid_m * num_pid_n) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group @@ -190,20 +207,22 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) # offs_bk = offs_k + qkv_offset * NK - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a_mask = (offs_am[:, None] < M) # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + qkv_offset * N * NK //infearure_per_bits + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + b_ptrs = b_ptr + qkv_offset * N * NK // infearure_per_bits + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) # g_ptrs = g_ptr + offs_k # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + qkv_offset * NK * N //gptq_group_size + offs_bn[None, :] - zeros_ptrs = zeros_ptr + qkv_offset * NK * N //gptq_group_size//infearure_per_bits + (offs_bn[None, :] // infearure_per_bits) + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = zeros_ptr + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + (offs_bn[None, :] // + infearure_per_bits) shifter = (offs_k % infearure_per_bits) * bits zeros_shifter = (offs_bn % infearure_per_bits) * bits @@ -214,24 +233,24 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ # tl.device_print("gidx, ", g_idx) currend_group_end = gptq_group_size - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros + 1) for k in range(0, num_pid_k): # g_idx = tl.load(g_ptrs) # if (k + 1) * BLOCK_SIZE_K > currend_group_end: - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros + 1) # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros).to(tl.float16) * scales # Scale and shift + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K @@ -239,29 +258,27 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size # if (k + 2) * BLOCK_SIZE_K > currend_group_end: - c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) if ADD_BIAS: - bias_mask = (offs_bn < N) + bias_mask = (offs_bn < N) offs_bn += qkv_offset * N bias_ptrs = bias_ptr + stride_cn * offs_bn - bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) accumulator += bias[None, :] - if ACT_TYPE == 1: - accumulator=relu(accumulator) + accumulator = relu(accumulator) elif ACT_TYPE == 2: - accumulator=gelu(accumulator) + accumulator = gelu(accumulator) elif ACT_TYPE == 3: - accumulator=silu(accumulator) - + accumulator = silu(accumulator) if ADD_RESIDUAL: - residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] res = tl.load(residual_ptrs, mask=c_mask, other=0.) - accumulator += res + accumulator += res tl.store(c_ptrs, accumulator, mask=c_mask) @@ -273,49 +290,65 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), + }, + num_stages=4, + num_warps=4), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), + }, + num_stages=2, + num_warps=8), triton.Config({ 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), + }, + num_stages=3, + num_warps=8), triton.Config({ 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), + }, + num_stages=2, + num_warps=4), ], key=['M', 'N', 'K'], nearest_power_of_two=True, @@ -326,11 +359,12 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ }, ) @triton.jit -def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, idx_ptr, bias_ptr, residual_ptr, - M, N, K, bits, maxq, gptq_group_size, - stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, - QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): +def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, idx_ptr, bias_ptr, residual_ptr, M, N, K, + bits, maxq, gptq_group_size, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, + stride_cn, stride_scales, stride_zeros, QKV_FUSED: tl.constexpr, + ADD_BIAS: tl.constexpr, ADD_RESIDUAL: tl.constexpr, ACT_TYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): """ Compute the matrix multiplication C = A x B. A is of shape (M, K) float16 @@ -353,7 +387,7 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) - qkv_offset = pid // (num_pid_m * num_pid_n) + qkv_offset = pid // (num_pid_m * num_pid_n) pid = pid % (num_pid_m * num_pid_n) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group @@ -362,20 +396,22 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) # offs_bk = offs_k + qkv_offset * NK - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) a_mask = (offs_am[:, None] < M) # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + qkv_offset * N * NK //infearure_per_bits + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + b_ptrs = b_ptr + qkv_offset * N * NK // infearure_per_bits + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn + ) # (BLOCK_SIZE_K, BLOCK_SIZE_N) # g_ptrs = g_ptr + offs_k # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + qkv_offset * NK * N //gptq_group_size + offs_bn[None, :] - zeros_ptrs = zeros_ptr + qkv_offset * NK * N //gptq_group_size//infearure_per_bits + (offs_bn[None, :] // infearure_per_bits) + scales_ptrs = scales_ptr + qkv_offset * NK * N // gptq_group_size + offs_bn[None, :] + zeros_ptrs = zeros_ptr + qkv_offset * NK * N // gptq_group_size // infearure_per_bits + (offs_bn[None, :] // + infearure_per_bits) shifter = (offs_k % infearure_per_bits) * bits zeros_shifter = (offs_bn % infearure_per_bits) * bits @@ -386,58 +422,67 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) currend_group_end = gptq_group_size - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) for k in range(0, num_pid_k): # g_idx = tl.load(g_ptrs) - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq zeros = (zeros + 1) # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros).to(tl.float16) * scales # Scale and shift + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk g_ptrs += BLOCK_SIZE_K - c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) if ADD_BIAS: - bias_mask = (offs_bn < N) + bias_mask = (offs_bn < N) offs_bn += qkv_offset * N bias_ptrs = bias_ptr + stride_cn * offs_bn - bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) accumulator += bias[None, :] - if ACT_TYPE == 1: - accumulator=relu(accumulator) + accumulator = relu(accumulator) elif ACT_TYPE == 2: - accumulator=gelu(accumulator) + accumulator = gelu(accumulator) elif ACT_TYPE == 3: - accumulator=silu(accumulator) - + accumulator = silu(accumulator) if ADD_RESIDUAL: - residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] res = tl.load(residual_ptrs, mask=c_mask, other=0.) - accumulator += res + accumulator += res tl.store(c_ptrs, accumulator, mask=c_mask) -def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, - bits, maxq, gptq_group_size, qkv_fused, add_bias, add_residual, g_idx = None, act_type = 0): +def gptq_fused_linear_triton(input, + qweight, + scales, + qzeros, + bias, + residual, + bits, + maxq, + gptq_group_size, + qkv_fused, + add_bias, + add_residual, + g_idx=None, + act_type=0): # print("gptq fused ", qkv_fused, add_bias, add_residual) assert input.is_cuda, "input is not in cuda" assert qweight.is_cuda, "qweight is not in cuda" @@ -446,27 +491,68 @@ def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, with torch.cuda.device(input.device): if qkv_fused: - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']) * 3, ) - output = torch.empty((input.shape[0]*3, qweight.shape[1]), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv( + qweight.shape[1], META['BLOCK_SIZE_N']) * 3,) + output = torch.empty((input.shape[0] * 3, qweight.shape[1]), device=input.device, dtype=torch.float16) else: - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv( + qweight.shape[1], META['BLOCK_SIZE_N']),) output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) if g_idx is None: - cai_gptq_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, bias, residual, - input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, - gptq_group_size, - input.stride(0), input.stride(1), qweight.stride(0), - qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0), - QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual, ACT_TYPE=act_type) + cai_gptq_matmul_248_kernel[grid](input, + qweight, + output, + scales, + qzeros, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type) else: - cai_gptq_idx_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, bias, residual, - input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, - gptq_group_size, - input.stride(0), input.stride(1), qweight.stride(0), - qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0), - QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual, ACT_TYPE=act_type) - if qkv_fused: + cai_gptq_idx_matmul_248_kernel[grid](input, + qweight, + output, + scales, + qzeros, + g_idx, + bias, + residual, + input.shape[0], + qweight.shape[1], + input.shape[1], + bits, + maxq, + gptq_group_size, + input.stride(0), + input.stride(1), + qweight.stride(0), + qweight.stride(1), + output.stride(0), + output.stride(1), + scales.stride(0), + qzeros.stride(0), + QKV_FUSED=qkv_fused, + ADD_BIAS=add_bias, + ADD_RESIDUAL=add_residual, + ACT_TYPE=act_type) + if qkv_fused: return output.view(3, input.shape[0], qweight.shape[1]) else: return output diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index eb42e3613768..bd663f03c115 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -1,33 +1,38 @@ +import math +import time + +import numpy as np +import pytest import torch import torch.nn as nn -import pytest -import time import transformers -from auto_gptq.quantization import GPTQ -from auto_gptq.modeling._utils import find_layers, pack_model,autogptq_post_init +from auto_gptq.modeling._utils import autogptq_post_init, find_layers, pack_model from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear - +from auto_gptq.quantization import GPTQ from auto_gptq.quantization.quantizer import Quantizer + from colossalai.gptq import CaiGPTQLinearOp, CaiQuantLinear -import math -import numpy as np + +wbits = 4 +trits = False +nsamples = 1 +percdamp = .01 +groupsize = 128 +act_order = False +sym = False -wbits=4 -trits=False -nsamples=1 -percdamp=.01 -groupsize=128 -act_order=False -sym=False class MLinear(nn.Module): + def __init__(self, infeature, outfeature): super(MLinear, self).__init__() self.linear = torch.nn.Linear(infeature, outfeature, dtype=torch.float16) + def forward(self, x): out = self.linear(x) return out - + + @torch.no_grad() def model_quant(model, inps, dev): print('Starting ...') @@ -36,14 +41,18 @@ def model_quant(model, inps, dev): dtype = next(iter(model.parameters())).dtype cache = {'i': 0} + class Catcher(nn.Module): + def __init__(self, module): super().__init__() self.module = module + def forward(self, inp, **kwargs): inps[cache['i']] = inp cache['i'] += 1 raise ValueError + layers[0] = Catcher(layers[0]) # for batch in inps: try: @@ -59,32 +68,34 @@ def forward(self, inp, **kwargs): quantizers = {} for i in range(len(layers)): layer = layers[i].to(dev) - subset = find_layers(layer) - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name]) - # gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False, trits=trits) - - def add_batch(name): - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - return tmp - - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - - for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0))[0] - - for h in handles: - h.remove() - for name in subset: + subset = find_layers(layer) + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + # gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False, trits=trits) + + def add_batch(name): + + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for j in range(nsamples): + outs[j] = layer(inps[j].unsqueeze(0))[0] + + for h in handles: + h.remove() + for name in subset: print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') - scale,zero,g_idx = gptq[name].fasterquant(percdamp=percdamp, group_size=groupsize, actorder=act_order) + scale, zero, g_idx = gptq[name].fasterquant(percdamp=percdamp, group_size=groupsize, actorder=act_order) # quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) - quantizers['%s' % (name)] = (gptq[name].layer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + quantizers['%s' % (name)] = (gptq[name].layer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) gptq[name].free() for j in range(nsamples): @@ -93,11 +104,11 @@ def tmp(_, inp, out): layers[i] = layer.cpu() del layer - del gptq + del gptq torch.cuda.empty_cache() inps, outs = outs, inps - + return quantizers @@ -106,10 +117,9 @@ def model_pack(model, quantizers, wbits, groupsize): return model -def cai_linear_pack(linear, scales, zeros, - out_qweight, out_qscales, out_qzeros, qg_idx, - infeatures, groupsize, bits): - g_idx = qg_idx.clone() if qg_idx is not None else torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) +def cai_linear_pack(linear, scales, zeros, out_qweight, out_qscales, out_qzeros, qg_idx, infeatures, groupsize, bits): + g_idx = qg_idx.clone() if qg_idx is not None else torch.tensor([i // groupsize for i in range(infeatures)], + dtype=torch.int32) scales = scales.t().contiguous() zeros = zeros.t().contiguous() @@ -133,7 +143,9 @@ def cai_linear_pack(linear, scales, zeros, intweight = [] for idx in range(infeatures): - intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(unsign_type) @@ -188,18 +200,18 @@ def get_model_param(model, quantizers): return scale, zero, g_idx + def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize): layers = find_layers(model) layers = {n: layers[n] for n in quantizers} with torch.no_grad(): for name in layers: _, scale, zero, g_idx = quantizers[name] - qweight, qscales, qzeros = cai_linear_pack(layers[name], scale, zero, - qweight, qscales, qzeros, g_idx, - layers[name].weight.shape[-1], groupsize, wbits) + qweight, qscales, qzeros = cai_linear_pack(layers[name], scale, zero, qweight, qscales, qzeros, g_idx, + layers[name].weight.shape[-1], groupsize, wbits) # print("cai pack", layers) - return qweight, qscales, qzeros + return qweight, qscales, qzeros def test_gptq_linear(): @@ -211,17 +223,19 @@ def test_gptq_linear(): bias = torch.zeros(outfeature).to(torch.float16).to(torch.cuda.current_device()) # wn = 16 # ptype = torch.int64 - + wn = 8 ptype = torch.int32 - qweight = torch.zeros(infeature//wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() - qscales = torch.zeros(infeature//groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() - qzeros = torch.zeros(infeature//groupsize, outfeature//wn, dtype=ptype, device=torch.cuda.current_device()).contiguous() + qweight = torch.zeros(infeature // wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() + qscales = torch.zeros(infeature // groupsize, outfeature, dtype=torch.float16, + device=torch.cuda.current_device()).contiguous() + qzeros = torch.zeros(infeature // groupsize, outfeature // wn, dtype=ptype, + device=torch.cuda.current_device()).contiguous() act_func = nn.SiLU() inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) - batch_inps = torch.randn(1, 2048, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) linear = MLinear(infeature, outfeature) linear.to(torch.cuda.current_device()) @@ -236,7 +250,6 @@ def test_gptq_linear(): # torch_out = act_func(torch_out) # batch_torch_out = act_func(batch_torch_out) - # linear.to("cuda") quantizers = model_quant(linear, inps, torch.cuda.current_device()) # qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, wbits, groupsize) @@ -250,9 +263,9 @@ def test_gptq_linear(): gptq_model = model_pack(linear, quantizers, wbits, groupsize) gptq_model.to(torch.cuda.current_device()) - gptq_model=autogptq_post_init(gptq_model, False) + gptq_model = autogptq_post_init(gptq_model, False) - qkv_fused=False + qkv_fused = False with torch.no_grad(): gptq_out = gptq_model(inps) @@ -289,6 +302,7 @@ def test_gptq_linear(): # max_diff = torch.max(torch.abs(batch_torch_out - batch_cai_out)) # print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + if __name__ == "__main__": test_gptq_linear() From f216e0b53759f6d7ba811ba3fd338dcea2fa4815 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 23 Aug 2023 15:59:42 +0800 Subject: [PATCH 5/5] [skip ci] --- tests/test_gptq/test_gptq_linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index bd663f03c115..7b3913928587 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -265,7 +265,6 @@ def test_gptq_linear(): gptq_model.to(torch.cuda.current_device()) gptq_model = autogptq_post_init(gptq_model, False) - qkv_fused = False with torch.no_grad(): gptq_out = gptq_model(inps)