From 28ad33ff8b815fff77ba0d901ecaa89f1ff8c753 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 11 Aug 2023 15:48:15 +0800 Subject: [PATCH 01/15] add gptq --- colossalai/gptq/__init__.py | 5 + colossalai/gptq/cai_gptq/__init__.py | 2 + colossalai/gptq/cai_gptq/cai_quant_linear.py | 142 ++++ colossalai/gptq/cai_gptq/gptq_autotune.py | 167 ++++ colossalai/gptq/cai_gptq/gptq_op.py | 101 +++ colossalai/gptq/cai_gptq/gptq_triton.py | 781 ++++++++++++++++++ colossalai/gptq/config.py | 36 + colossalai/gptq/csrc/gptq_act_linear.cu | 387 +++++++++ .../gptq/csrc/includes/conversion_utils.h | 641 ++++++++++++++ .../gptq/csrc/includes/ds_kernel_utils.h | 52 ++ .../csrc/includes/inference_cuda_layers.h | 32 + colossalai/gptq/csrc/pt_binding.cpp | 23 + colossalai/gptq/gptq_utils/__init__.py | 1 + colossalai/gptq/gptq_utils/gptq.py | 236 ++++++ colossalai/gptq/gptq_utils/quant/__init__.py | 5 + .../gptq/gptq_utils/quant/custom_autotune.py | 194 +++++ .../gptq/gptq_utils/quant/fused_attn.py | 204 +++++ colossalai/gptq/gptq_utils/quant/fused_mlp.py | 288 +++++++ .../gptq/gptq_utils/quant/quant_linear.py | 422 ++++++++++ colossalai/gptq/gptq_utils/quant/quantizer.py | 127 +++ .../gptq/gptq_utils/quant/triton_norm.py | 92 +++ colossalai/gptq/gptq_utils/utils/__init__.py | 3 + colossalai/gptq/gptq_utils/utils/datautils.py | 193 +++++ colossalai/gptq/gptq_utils/utils/export.py | 37 + .../gptq/gptq_utils/utils/modelutils.py | 83 ++ colossalai/gptq/inference_builder.py | 761 +++++++++++++++++ tests/test_gptq/linear_act_fusion_bench.py | 385 +++++++++ tests/test_gptq/quant_llama.py | 569 +++++++++++++ tests/test_gptq/run_gptq.sh | 19 + tests/test_gptq/test_linear_act_fusion.py | 402 +++++++++ tests/test_gptq/test_quant_llama.py | 530 ++++++++++++ 31 files changed, 6920 insertions(+) create mode 100644 colossalai/gptq/__init__.py create mode 100644 colossalai/gptq/cai_gptq/__init__.py create mode 100644 colossalai/gptq/cai_gptq/cai_quant_linear.py create mode 100644 colossalai/gptq/cai_gptq/gptq_autotune.py create mode 100644 colossalai/gptq/cai_gptq/gptq_op.py create mode 100644 colossalai/gptq/cai_gptq/gptq_triton.py create mode 100644 colossalai/gptq/config.py create mode 100644 colossalai/gptq/csrc/gptq_act_linear.cu create mode 100644 colossalai/gptq/csrc/includes/conversion_utils.h create mode 100644 colossalai/gptq/csrc/includes/ds_kernel_utils.h create mode 100644 colossalai/gptq/csrc/includes/inference_cuda_layers.h create mode 100644 colossalai/gptq/csrc/pt_binding.cpp create mode 100644 colossalai/gptq/gptq_utils/__init__.py create mode 100644 colossalai/gptq/gptq_utils/gptq.py create mode 100644 colossalai/gptq/gptq_utils/quant/__init__.py create mode 100644 colossalai/gptq/gptq_utils/quant/custom_autotune.py create mode 100644 colossalai/gptq/gptq_utils/quant/fused_attn.py create mode 100644 colossalai/gptq/gptq_utils/quant/fused_mlp.py create mode 100644 colossalai/gptq/gptq_utils/quant/quant_linear.py create mode 100644 colossalai/gptq/gptq_utils/quant/quantizer.py create mode 100644 colossalai/gptq/gptq_utils/quant/triton_norm.py create mode 100644 colossalai/gptq/gptq_utils/utils/__init__.py create mode 100644 colossalai/gptq/gptq_utils/utils/datautils.py create mode 100644 colossalai/gptq/gptq_utils/utils/export.py create mode 100644 colossalai/gptq/gptq_utils/utils/modelutils.py create mode 100644 colossalai/gptq/inference_builder.py create mode 100644 tests/test_gptq/linear_act_fusion_bench.py create mode 100644 tests/test_gptq/quant_llama.py create mode 100644 tests/test_gptq/run_gptq.sh create mode 100644 tests/test_gptq/test_linear_act_fusion.py create mode 100644 tests/test_gptq/test_quant_llama.py diff --git a/colossalai/gptq/__init__.py b/colossalai/gptq/__init__.py new file mode 100644 index 000000000000..55b7f3c85b2d --- /dev/null +++ b/colossalai/gptq/__init__.py @@ -0,0 +1,5 @@ +from .config import CaiInferenceConfig +from .inference_builder import InferenceBuilder +from torch import nn + + diff --git a/colossalai/gptq/cai_gptq/__init__.py b/colossalai/gptq/cai_gptq/__init__.py new file mode 100644 index 000000000000..2da4309cca0c --- /dev/null +++ b/colossalai/gptq/cai_gptq/__init__.py @@ -0,0 +1,2 @@ +from .gptq_triton import gptq_fused_linear_triton +from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py new file mode 100644 index 000000000000..f6ba8ab0394b --- /dev/null +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -0,0 +1,142 @@ + +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd +from .gptq_op import CaiGPTQLinearOp +from ..config import CaiInferenceConfig +from .gptq_triton import gptq_linear_llama +import triton + +class CaiQuantLinear(nn.Module): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else infeatures + + self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + # self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int64)) + # self.order_qzeros = torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int64) + # self.register_buffer('input_idx', torch.zeros(infeatures], dtype=torch.int32)) + + + if bias: + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + cai_inf_config = CaiInferenceConfig(fp16=True, + gptq_group_size=self.groupsize) + self.gptq_linear = CaiGPTQLinearOp(cai_inf_config) + self.printed = False + self.reorder_zeros = False + def pack(self, linear, scales, zeros, g_idx=None): + + + g_idx = g_idx.clone() if g_idx is not None else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + wn = 16 + pbits = 64 + ptype = torch.int64 + unsign_type = np.uint64 + sign_type = np.int64 + + # wn = 8 + # pbits = 32 + # ptype = torch.int32 + # unsign_type = np.uint32 + # sign_type = np.int32 + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + # print("weight shape ", intweight.shape, qweight.shape, out_qweight.shape, bits) + # print("weight shape ", intweight[0].shape, qweight[0].shape, out_qweight[0].shape) + # print("weight value ", intweight[0], qweight[0]) + + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qweight[row] |= intweight[j] << ( self.bits * (j - i)) + i += pbits // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous() #.to("cuda") + self.qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qzeros[:, col] |= zeros[:, j] << ( self.bits * (j - i)) + i += pbits // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros #.to(torch.cuda.current_device()) + self.qzeros.data.copy_(qzeros) + + + def forward(self, x): + + # if self.reorder_zeros == False: + # for i in range(self.g_idx.shape[0]): + # idx = self.g_idx[i] + # self.order_qzeros[i,:] = self.qzeros[idx,:] + # gptq_out = gptq_linear_llama(x, self.qweight, self.scales, self.qzeros, self.g_idx, + # self.bits, self.maxq) + + cai_out = self.gptq_linear(x, + self.qweight, + self.scales, + self.qzeros, + bias = self.bias) + print("shape is ", cai_out.shape) + return cai_out + +def make_cai_quant_linear(module, names, bits, groupsize, name=''): + if isinstance(module, CaiQuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + delattr(module, attr) + setattr(module, attr, CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + for name1, child in module.named_children(): + make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + diff --git a/colossalai/gptq/cai_gptq/gptq_autotune.py b/colossalai/gptq/cai_gptq/gptq_autotune.py new file mode 100644 index 000000000000..18e9969d8d00 --- /dev/null +++ b/colossalai/gptq/cai_gptq/gptq_autotune.py @@ -0,0 +1,167 @@ +import math +import time +import torch + +class AutoTune: + def __init__(self, tune_func, warmup=10, bech_run=20): + + self.func = tune_func + self.warmup_num = warmup + self.bech_run = bech_run + self.config_caches = {} + + def prune_configs(self, tune_config): + + norm_configs = [] + linear_configs = [] + + if tune_config['qkv_fused']: + max_in = 2**int(math.log2(tune_config['in_dim'])) + in_dim = tune_config['in_dim'] + else: + max_in = 2**int(math.log2(tune_config['in_dim'])) + in_dim = tune_config['in_dim'] + + max_out = 2**int(math.log2(tune_config['out_dim'])) + if max_out > 1024: + max_out = 1024 + m = 2 + n = 64 + x = 64 + y = 1 + + # while n <= max_out: + # ret_config = { + # "linear_x": n, + # "linear_y": in_dim, + # } + # linear_configs.append(ret_config) + # n = n * 2 + while n <= max_out: + m = 64 + while m < in_dim: + ret_config = { + "linear_x": n, + "linear_y": m, + } + linear_configs.append(ret_config) + m = m * 2 + ret_config = { + "linear_x": n, + "linear_y": in_dim, + } + linear_configs.append(ret_config) + n = n * 2 + + # if tune_config['act_type'] == 0: + # while n <= max_out: + # m = 64 + # while m <= max_in: + # ret_config = { + # "norm_x": 0, + # "norm_y": 0, + # "linear_x": n, + # "linear_y": m, + # } + # linear_configs.append(ret_config) + # m = m * 2 + # n = n * 2 + # elif tune_config['act_type'] > 0: + # while n <= max_out: + # ret_config = { + # "norm_x": 0, + # "norm_y": 0, + # "linear_x": n, + # "linear_y": in_dim, + # } + # linear_configs.append(ret_config) + # n = n * 2 + return linear_configs + + def warmup(self, tune_config, *args, **kwargs): + # if tune_config['qkv_fused']: + # output = torch.zeros(3, tune_config['input_len'], tune_config['out_dim'], + # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() + # else: + # output = torch.zeros(tune_config['input_len'], tune_config['out_dim'], + # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() + + for i in range(0, self.warmup_num): + # out = self.func(*args[:6], output, *args[7:],**kwargs) + out = self.func(*args, **kwargs) + + def benchmark(self, tune_config, *args, **kwargs): + + self.warmup(tune_config, *args, **kwargs) + linear_configs = self.prune_configs(tune_config) + # print(ret_configs) + times = {} + best_norm_x = 512 + best_norm_y = 1 + best_linear_x = 512 + best_linear_y = 256 + # if best_linear_x > tune_config['out_dim']: + # best_linear_x = 2**int(math.log2(tune_config['out_dim'])) + # if best_linear_y > tune_config['in_dim'] and tune_config['qkv_fused'] == False: + # best_linear_y = 2**int(math.log2(tune_config['in_dim'])) + # if best_linear_y > tune_config['in_dim'] and tune_config['qkv_fused']: + # best_linear_y = 2**int(math.log2(tune_config['in_dim'] //3)) + if best_norm_x > tune_config['input_dim']: + best_norm_x = 2**int(math.log2(tune_config['input_dim'])) + + if tune_config['wdtype'] == torch.int8: + nweights = 2 + elif tune_config['wdtype'] == torch.int32: + nweights = 8 + elif tune_config['wdtype'] == torch.int64: + nweights = 16 + times = {} + + # if tune_config['qkv_fused']: + # output = torch.zeros(3, tune_config['input_len'], tune_config['out_dim'], + # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() + # else: + # output = torch.zeros(tune_config['input_len'], tune_config['out_dim'], + # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() + + for config in linear_configs: + linear_x = config['linear_x'] + linear_y = config['linear_y'] + print(config) + start = time.time() + for run in range(0, self.bech_run): + # out = self.func(*args[:6], output, *args[7:-2], linear_x, linear_y) + out = self.func(*args[:-2], linear_x, linear_y) + + torch.cuda.synchronize() + end = time.time() + times[' '.join(map(str,config.values()))] = end - start + # print(f"{config}: {end-start:.6f}") + sorted_dict = sorted(times.items(), key=lambda x:x[1]) + values = sorted_dict[0][0].split() + # print(sorted_dict) + best_linear_x = int(values[0]) + best_linear_y = int(values[1]) + + times = {} + + key = ' '.join(map(str,tune_config.values())) + + ret_config = { + "linear_x": best_linear_x, + "linear_y": best_linear_y, + } + self.config_caches[key] = ret_config + # print("best config:", tune_config, ret_config) + def get_best_config(self, tune_config, *args, **kwargs): + key = ' '.join(map(str,tune_config.values())) + + if key in self.config_caches: + return self.config_caches[key] + else: + # print(tune_config) + self.benchmark(tune_config, *args, **kwargs) + return self.config_caches[key] + + + diff --git a/colossalai/gptq/cai_gptq/gptq_op.py b/colossalai/gptq/cai_gptq/gptq_op.py new file mode 100644 index 000000000000..13af9953395c --- /dev/null +++ b/colossalai/gptq/cai_gptq/gptq_op.py @@ -0,0 +1,101 @@ + +from ..config import CaiInferenceConfig +from ..inference_builder import inference_cuda_module +from .gptq_autotune import AutoTune +from .gptq_triton import gptq_fused_linear_triton +import torch +class BaseOp(torch.nn.Module): + inference_cuda_module = inference_cuda_module + def __init__(self, config: CaiInferenceConfig): + super(BaseOp, self).__init__() + self.config = config + if BaseOp.inference_cuda_module is None: + BaseOp.inference_cuda_module = inference_cuda_module + + +class CaiGPTQLinearOp(BaseOp): + autotune = None + + def __init__(self, config: CaiInferenceConfig): + super(CaiGPTQLinearOp, self).__init__(config) + + self.linear_func = self.inference_cuda_module.gptq_act_linear_fp16 + + self.group_size = config.gptq_group_size + self.bits = config.gptq_quant_bits + self.maxq = 2**self.bits - 1 + self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device()) + if CaiGPTQLinearOp.autotune == None: + CaiGPTQLinearOp.autotune = AutoTune(self.linear_func) + + def forward(self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zeros: torch.Tensor, + act_type = 0, + bias: torch.Tensor = None, + residual: torch.Tensor=None, + qkv_fused = False): + add_bias = True + if bias is None: + bias = self.empty_tensor + add_bias = False + + add_residual = True + if residual is None: + residual = self.empty_tensor + add_residual = False + x = input.view(-1, input.shape[-1]) + + if x.shape[0] > 1: + out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual, + self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual, act_type=act_type) + if qkv_fused: + out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) + else: + out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) + else: + print("inut shape, ", input.shape) + config = { + "input_dim": input.shape[-1], + "input_len": input.shape[0] * input.shape[1] , + "add_bias": add_bias, + "add_residual": add_residual, + "qkv_fused": qkv_fused, + "act_type": act_type, + "out_dim": weight.shape[-1], + "in_dim": input.shape[-1], + "wdtype": weight.dtype + } + + best_config = CaiGPTQLinearOp.autotune.get_best_config(config, + input, + weight, + weight_scales, + weight_zeros, + bias, + residual, + self.group_size, + act_type, + add_bias, + add_residual, + qkv_fused, + 128, + 128) + block_size_x = best_config['linear_x'] + block_size_y = best_config['linear_y'] + out = self.linear_func(input, + weight, + weight_scales, + weight_zeros, + bias, + residual, + self.group_size, + act_type, + add_bias, + add_residual, + qkv_fused, + block_size_x, + block_size_y) + return out \ No newline at end of file diff --git a/colossalai/gptq/cai_gptq/gptq_triton.py b/colossalai/gptq/cai_gptq/gptq_triton.py new file mode 100644 index 000000000000..44ec38fd80b5 --- /dev/null +++ b/colossalai/gptq/cai_gptq/gptq_triton.py @@ -0,0 +1,781 @@ +import triton +import triton.language as tl +import torch +from ..gptq_utils.quant import custom_autotune +# from ..ops.triton.kernels.activations_kernels import relu, gelu, silu +# code based https://github.com/fpgaminer/GPTQ-triton + # triton.Config({ + # 'BLOCK_SIZE_M': 32, + # 'BLOCK_SIZE_N': 32, + # 'BLOCK_SIZE_K': 128, + # 'GROUP_SIZE_M': 8 + # }, num_stages=2, num_warps=4), + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def cosh(x): + exp_x = tl.exp(x) + return (exp_x + 1.0 / exp_x) * 0.5 + + +# a Triton implementation of the most used activations +# See for instance http://arxiv.org/abs/1606.08415 for an overview + + +# ReLU +@triton.jit +def relu(x): + """ + ReLU_ activation function + + .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html + """ + return tl.where(x >= 0, x, 0.0) + + +@triton.jit +def squared_relu(x): + """ + Squared ReLU activation, as proposed in the Primer_ paper. + + .. _Primer: https://arxiv.org/abs/2109.08668 + """ + x_sq = x * x + return tl.where(x > 0.0, x_sq, 0.0) + + +@triton.jit +def star_relu(x): + """ + Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper. + + .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf + """ + x_sq = x * x + return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472 + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + return tl.where(x >= 0.0, x, 0.01 * x) + + +@triton.jit +def gelu(x): + """ + GeLU_ activation - Gaussian error linear unit + + .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x))) + + +@triton.jit +def smelu(x): + """ + SmeLU_ activation - Smooth ReLU with beta=2.0 + + .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf + """ + beta = 2.0 + + relu = tl.where(x >= beta, x, 0.0) + return tl.where( + tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) + + +@triton.jit +def silu(x): + return x*tl.sigmoid(x) + +@custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, +) +@triton.jit +def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, + M, N, K, bits, maxq, gptq_group_size, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, + QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//16, N) int64 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 64 // bits + + pid = tl.program_id(axis=0) + NK = K + + # if QKV_FUSED: + # NK = K//3 + # else: + # NK = K + # NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + qkv_offset * N * NK //infearure_per_bits + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N //gptq_group_size + offs_bn[None, :] + zeros_ptrs = zeros_ptr + qkv_offset * NK * N //gptq_group_size//infearure_per_bits + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_idx_base = tl.arange(0, BLOCK_SIZE_K) + g_idx_base = g_idx_base // gptq_group_size + g_idx = g_idx_base + # tl.device_print("gidx, ", g_idx) + + currend_group_end = 0 + for k in range(0, num_pid_k): + # g_idx = tl.load(g_ptrs) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size + # if (k + 2) * BLOCK_SIZE_K > currend_group_end: + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = (offs_bn < N) + offs_bn += qkv_offset * N + bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + + if ACT_TYPE == 1: + accumulator=relu(accumulator) + elif ACT_TYPE == 2: + accumulator=gelu(accumulator) + elif ACT_TYPE == 3: + accumulator=silu(accumulator) + + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + + +@custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, +) +@triton.jit +def cai_gptq_v2_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, + M, N, K, bits, maxq, gptq_group_size, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, + QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//16, N) int64 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 64 // bits + + pid = tl.program_id(axis=0) + NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + qkv_offset * N * NK //infearure_per_bits + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N //gptq_group_size + offs_bn[None, :] + zeros_ptrs = zeros_ptr + qkv_offset * NK * N //gptq_group_size//infearure_per_bits + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_idx_base = tl.arange(0, BLOCK_SIZE_K) + g_idx_base = g_idx_base // gptq_group_size + g_idx = g_idx_base + # tl.device_print("gidx, ", g_idx) + + currend_group_end = gptq_group_size + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + for k in range(0, num_pid_k): + # g_idx = tl.load(g_ptrs) + if (k + 1) * BLOCK_SIZE_K > currend_group_end: + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size + # if (k + 2) * BLOCK_SIZE_K > currend_group_end: + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = (offs_bn < N) + offs_bn += qkv_offset * N + bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + + if ACT_TYPE == 1: + accumulator=relu(accumulator) + elif ACT_TYPE == 2: + accumulator=gelu(accumulator) + elif ACT_TYPE == 3: + accumulator=silu(accumulator) + + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, +) +@triton.jit +def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, idx_ptr, bias_ptr, residual_ptr, + M, N, K, bits, maxq, gptq_group_size, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, + QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//16, N) int64 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + """ + infearure_per_bits = 64 // bits + + pid = tl.program_id(axis=0) + NK = K + + # if QKV_FUSED: + # NK = K//3 + # else: + # NK = K + # NK = K + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) + qkv_offset = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_bk = offs_k + qkv_offset * NK + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + qkv_offset * N * NK //infearure_per_bits + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + # g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + qkv_offset * NK * N //gptq_group_size + offs_bn[None, :] + zeros_ptrs = zeros_ptr + qkv_offset * NK * N //gptq_group_size//infearure_per_bits + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + g_ptrs = idx_ptr + offs_bk + g_idx = tl.load(g_ptrs) + # tl.device_print("gidx, ", g_idx) + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + currend_group_end = gptq_group_size + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + + for k in range(0, num_pid_k): + # g_idx = tl.load(g_ptrs) + if (k + 1) * BLOCK_SIZE_K > currend_group_end: + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros).to(tl.float16) * scales # Scale and shift + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + + if ADD_BIAS: + bias_mask = (offs_bn < N) + offs_bn += qkv_offset * N + bias_ptrs = bias_ptr + stride_cn * offs_bn + bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + accumulator += bias[None, :] + + + if ACT_TYPE == 1: + accumulator=relu(accumulator) + elif ACT_TYPE == 2: + accumulator=gelu(accumulator) + elif ACT_TYPE == 3: + accumulator=silu(accumulator) + + + if ADD_RESIDUAL: + residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + res = tl.load(residual_ptrs, mask=c_mask, other=0.) + accumulator += res + + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, + bits, maxq, gptq_group_size, qkv_fused, add_bias, add_residual, idx = None, act_type = 0): + # print("gptq fused ", qkv_fused, add_bias, add_residual) + with torch.cuda.device(input.device): + if qkv_fused: + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']) * 3, ) + output = torch.empty((input.shape[0]*3, qweight.shape[1]), device=input.device, dtype=torch.float16) + else: + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) + if idx is None: + cai_gptq_v2_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, bias, residual, + input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, + gptq_group_size, + input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0), + QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual, ACT_TYPE=act_type) + else: + cai_gptq_idx_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, idx, bias, residual, + input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, + gptq_group_size, + input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0), + QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual, ACT_TYPE=act_type) + if qkv_fused: + return output.view(3, input.shape[0], qweight.shape[1]) + else: + return output + + + +# code based https://github.com/fpgaminer/GPTQ-triton +@custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, +) +@triton.jit +def gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def gptq_linear_llama(x, qweight, scales, qzeros, g_idx, + bits, maxq): + + out_shape = x.shape[:-1] + (qweight.shape[-1], ) + input = x.reshape(-1, x.shape[-1]) + # print("input shape:", input.shape)/ + with torch.cuda.device(input.device): + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) + gptq_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], + qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + # output = output.reshape(out_shape) + + return output.reshape(out_shape) \ No newline at end of file diff --git a/colossalai/gptq/config.py b/colossalai/gptq/config.py new file mode 100644 index 000000000000..fc2dd4cb9661 --- /dev/null +++ b/colossalai/gptq/config.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import json +import torch +from enum import IntEnum + +DEFAULT_INTERMEDIATE_SIZE = -1 +class ActivationFuncType(IntEnum): + UNKNOWN = 0 + ReLU = 1 + GELU = 2 + SiLU = 3 + GATED_GELU = 4 + GATED_SILU = 5 + + +class CaiInferenceConfig(): + + + def __init__(self, + fp16=True, + gptq=False, + gptq_group_size=128, + gptq_quant_bits=4, + gptq_weight_dtype=torch.int64 + ): + self.fp16 = fp16 + self.gptq = gptq + self.gptq_group_size = gptq_group_size + self.gptq_quant_bits = gptq_quant_bits + self.gptq_weight_dtype = gptq_weight_dtype + + diff --git a/colossalai/gptq/csrc/gptq_act_linear.cu b/colossalai/gptq/csrc/gptq_act_linear.cu new file mode 100644 index 000000000000..b622233df89e --- /dev/null +++ b/colossalai/gptq/csrc/gptq_act_linear.cu @@ -0,0 +1,387 @@ +#include "conversion_utils.h" +#include "inference_cuda_layers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#define SHARE_MEM_SIZE (48 * 1024) +inline __device__ float relu(const float x) { return x < 0 ? 0 : x; } +inline __device__ float gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} +inline __device__ float silu(const float x) +{ + return x / (1 + expf(-x)); +} +/*** +input: the input size is [b, l, m] +weight: the weight size is [m/size(TW)*2, n] +weight_scales: the weight scales size is [m/group_size, n] +weight_zeros: the weight scales size is [m/group_size, n/size(TW)*2] +bias: linear bias [n] +input_dim0: m +input_dim1: b * l +weight_dim0: n +weight_dim1: m +block_size_m: m for one gpu thread block +block_size_n: n for one gpu thread block +the computation block is [block_size_m, block_size_n] for one gpu thread block +group_size: the group size for gptq quant +add_bias: the linear has bias or not +***/ +template +__global__ void gptq_gemm(T* input, + TW* weight, + T* weight_scales, + TW* weight_zeros, + T* bias, + T* residual, + T* output, + uint64_t input_dim0, + uint64_t input_dim1, + uint64_t weight_dim0, + uint64_t weight_dim1, + uint64_t group_size, + int32_t act_type, + bool add_bias, + bool add_residual, + bool qkv_fused, + uint64_t block_size_m, + uint64_t block_size_n) +{ + const uint32_t n_weights = sizeof(TW) * 2; // number of compressed weights in a TW. + + uint64_t block_offset = blockIdx.x; + uint64_t local_tid = threadIdx.x; + uint64_t block_tnum = blockDim.x; + uint64_t block_m_start = blockIdx.y * block_size_m; + uint64_t block_m_end = (blockIdx.y + 1) * block_size_m; + block_m_end = std::min(block_m_end, weight_dim1 * n_weights); + + uint64_t group_step = 32; + group_step = group_step - group_step % n_weights; + group_step = std::min(group_step, group_size); + + uint64_t group_block = group_step / n_weights; + uint64_t table_iter = (group_step / 2 * 256) / block_tnum; + uint64_t col_offset = block_size_n * block_offset; + + __shared__ float table[16][256]; // look-up table, for 32 inputs + __shared__ float i2sum[16]; + return; + + uint64_t qkv_offset = 0; + uint64_t qkv_out_base_offset = col_offset; + uint64_t bias_base_offset = col_offset; + uint64_t split_m_size = weight_dim1 * n_weights; + if (qkv_fused) + { + split_m_size = weight_dim1 * n_weights / 3; + qkv_offset = block_m_start / split_m_size; + qkv_out_base_offset = qkv_offset * input_dim1 * weight_dim0 + col_offset; + bias_base_offset = qkv_offset * weight_dim0 + col_offset; + } + + float tmp_w_res = conversion::to(0.0); + float tmp_z_res = conversion::to(0.0); + float tmp_final_res = conversion::to(0.0); + float tmp_weight_scales; + float tmp_weight_zero; + + uint64_t current_group_size = group_size; + uint64_t scale_dim1_ind = block_m_start / group_size; + + for (uint64_t i = block_m_start; i < block_m_end; i += current_group_size) + { + if (i + current_group_size > block_m_end) + current_group_size = block_m_end - i; + + // // index of weight scale + // uint64_t dind = (i / group_size) * weight_dim0 + col_offset + local_tid; + // int32_t i_zero = + // ((weight_zeros[dind / n_weights] >> (((col_offset + local_tid) & 0xf) * 4)) & 0xf) + 1; + + // tmp_weight_scales = conversion::to(weight_scales[dind]); + // tmp_weight_zero = conversion::to(i_zero); + // if (i + current_group_size > block_m_end) + // current_group_size = block_m_end - i; + + // index of weight scale + uint64_t scale_index = + scale_dim1_ind * weight_dim0 + col_offset + local_tid; + // 4 is 4bits weight. 0xf is mask for 4 bits weight. 1 is for gptq algorithm. + int32_t i_zero = ((weight_zeros[scale_index / n_weights] >> + ((scale_index & 0xf) * 4)) & + 0xf) + + 1; + + tmp_weight_scales = conversion::to(weight_scales[scale_index]); + tmp_weight_zero = conversion::to(i_zero); + scale_dim1_ind += 1; + for (uint64_t j = 0; j < current_group_size; j += group_step) + { + +// compute lookup table +#pragma unroll + for (uint64_t k = 0; k < table_iter; k++) + { + + // uint64_t table_id = k * block_tnum + local_tid; + // uint64_t dind = table_id & 0xff; + // uint64_t tid = table_id >> 8; + // uint64_t input_offset = i + j + tid * 2; + + uint64_t table_id = k * block_tnum + local_tid; + uint64_t weight_id = table_id & 0xff; + uint64_t input_id = table_id >> 8; + // 2 is number of inputs for one table elements. + uint64_t input_offset = (i + j + input_id * 2) % split_m_size; + + // float i1, i2; + + // float i1 = relu(conversion::to(input[input_offset])); + // float i2 = relu(conversion::to(input[input_offset + 1])); + float i1 = (conversion::to(input[input_offset])); + float i2 = (conversion::to(input[input_offset + 1])); + + i2sum[input_id] = i1 + i2; + + int32_t iw1 = weight_id & 0xf; + int32_t iw2 = weight_id >> 4; + + float w1 = conversion::to(iw1); + float w2 = conversion::to(iw2); + + table[input_id][weight_id] = w1 * i1 + w2 * i2; + } + __syncthreads(); +#pragma unroll + for (uint64_t k = 0; k < group_block; k++) + { + + uint64_t base_weight_offset = ((i + j) / n_weights + k) * weight_dim0; + uint64_t dind = base_weight_offset + col_offset + local_tid; + + TW w = weight[dind]; + +#pragma unroll + for (uint64_t z = 0; z < n_weights / 2; z++) + { + uint32_t k1 = k * n_weights / 2 + z; + TW w1 = (w >> (z * 8)) & 0xff; + + tmp_w_res += table[k1][w1]; + tmp_z_res += i2sum[k1]; + } + } + } + + tmp_final_res += + (tmp_w_res - tmp_z_res * tmp_weight_zero) * tmp_weight_scales; + tmp_w_res = conversion::to(0.0); + tmp_z_res = conversion::to(0.0); + } + + + if(col_offset + local_tid < input_dim0 * input_dim1) + { + uint64_t bias_offset = bias_base_offset + local_tid; + float bias_v = 0; + float residual_v = 0; + if (add_bias && blockIdx.y == 0) + { + bias_v = conversion::to(bias[bias_offset]); + tmp_final_res += bias_v; + } + uint64_t dind = qkv_out_base_offset + local_tid; + if(act_type == 1) + { + tmp_final_res = relu(tmp_final_res); + } + else if(act_type == 2) + { + tmp_final_res = gelu(tmp_final_res); + } + else if(act_type == 3) + { + tmp_final_res = silu(tmp_final_res); + } + + if (add_residual && blockIdx.y == 0){ + residual_v = conversion::to(residual[dind]); + tmp_final_res += residual_v; + } + T tmp_res = conversion::to(tmp_final_res); + // float *o = (float*)output; + atomicAdd(&output[dind], tmp_res); + // atomicAdd(&output[dind], tmp_final_res); + + } +} + + + +template +at::Tensor gptq_act_linear_layer(at::Tensor& input, + at::Tensor& weight, + at::Tensor& weight_scales, + at::Tensor& weight_zeros, + at::Tensor& bias, + at::Tensor& residual, + int64_t group_size, + int32_t act_type, + int32_t add_bias, + int32_t add_residual, + int32_t qkv_fused, + uint64_t block_size_x, + uint64_t block_size_y) +{ + + uint64_t input_dim0 = input.sizes()[2]; + uint64_t input_dim1 = input.sizes()[0] * input.sizes()[1]; + + uint64_t weight_dim0 = weight.sizes()[1]; + uint64_t weight_dim1 = weight.sizes()[0]; + + + auto options = + torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA); + + std::vector out_shape; + if (qkv_fused) + out_shape.push_back(3); + out_shape.push_back(input.sizes()[0]); + out_shape.push_back(input.sizes()[1]); + out_shape.push_back(weight.sizes()[1]); + + at::Tensor output = at::zeros(out_shape, options); + + T* input_ptr = (T*)input.data_ptr(); + TW* weight_ptr = (TW*)weight.data_ptr(); + T* weight_scales_ptr = (T*)weight_scales.data_ptr(); + TW* weight_zeros_ptr = (TW*)weight_zeros.data_ptr(); + T* bias_ptr = (T*)bias.data_ptr(); + T* output_ptr = (T*)output.data_ptr(); + T* residual_ptr = (T*)residual.data_ptr(); + // at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(); + auto stream = at::cuda::getCurrentCUDAStream().stream(); +// #define BENCHMARK +#ifdef BENCHMARK + uint32_t block_xs[] = {128, 256, 512}; + uint32_t block_ys[] = {128, 256, 512, 1024}; + + for (uint32_t i = 0; i < 3; i++) + { + for (uint32_t j = 0; j < 4; j++) + { + + block_size_x = block_xs[i]; + block_size_y = block_ys[j]; + uint32_t warm_up = 2; + uint32_t bench = 5; + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + for (uint32_t k = 0; k < warm_up + bench; k++) + { + + if (k == warm_up) + start = std::chrono::high_resolution_clock::now(); + +#endif + uint64_t block_size_m = block_size_y; + uint64_t block_size_n = block_size_x; + + uint64_t block_tnum = block_size_x; + + // printf("block size m %d %d\n", weight_dim1, weight_dim0); + // printf("block size m %d %d\n", input_dim1, input_dim0); + + if (input_dim1 == 1) + { + + dim3 block_dim(block_tnum, 1, 1); + dim3 grid_dim(weight_dim0 / block_tnum, + (weight_dim1 * sizeof(TW) * 2 + block_size_y - 1) / block_size_y, + 1); + // printf("block size m %d %d\n", weight_dim1, weight_dim0); + // printf("block size m %d %d\n", input_dim1, input_dim0); + // printf("block size m %d %d\n", block_tnum, weight_dim0 / block_tnum); + // printf("block size m %d %d\n", weight_dim1 * sizeof(TW) * 2 / block_size_y, input_dim0); + gptq_gemm + <<>>(input_ptr, + weight_ptr, + weight_scales_ptr, + weight_zeros_ptr, + bias_ptr, + residual_ptr, + output_ptr, + input_dim0, + input_dim1, + weight_dim0, + weight_dim1, + group_size, + act_type, + add_bias, + add_residual, + qkv_fused, + block_size_m, + block_size_n); + } + else + { + printf("cuda kernel not support batch * seq_len > 1\n"); + } + +#ifdef BENCHMARK + } + end = std::chrono::high_resolution_clock::now(); + double sec = + (double)(std::chrono::duration_cast( + end - start) + .count()) / + 1e9 / 5; + + printf("block x: %d, block y: %d, %.8f\n", + block_size_x, + block_size_y, + sec); + } + } +#endif + // float *o = (float*)output_ptr; + // for(int i = 0; i < 64; i ++){ + // printf("%f ", o[i]); + // } + // printf("\n"); + return output; +} + +#define INSTANTIATE_ACT_GPTQ_LINEAR(T, TW) \ + template at::Tensor gptq_act_linear_layer( \ + at::Tensor & input, \ + at::Tensor & weight, \ + at::Tensor & weight_scales, \ + at::Tensor & weight_zeros, \ + at::Tensor & bias, \ + at::Tensor & residual, \ + int64_t group_size, \ + int32_t act_type, \ + int32_t add_bias, \ + int32_t add_residual, \ + int32_t qkv_fused, \ + uint64_t block_size_x, \ + uint64_t block_size_y); + +// INSTANTIATE_ACT_GPTQ_LINEAR(float, uint64_t) +INSTANTIATE_ACT_GPTQ_LINEAR(__half, uint64_t) +// INSTANTIATE_ACT_GPTQ_LINEAR(float, uint32_t) +INSTANTIATE_ACT_GPTQ_LINEAR(__half, uint32_t) +// INSTANTIATE_ACT_GPTQ_LINEAR(float, uint8_t) +INSTANTIATE_ACT_GPTQ_LINEAR(__half, uint8_t) diff --git a/colossalai/gptq/csrc/includes/conversion_utils.h b/colossalai/gptq/csrc/includes/conversion_utils.h new file mode 100644 index 000000000000..3d31e37de364 --- /dev/null +++ b/colossalai/gptq/csrc/includes/conversion_utils.h @@ -0,0 +1,641 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include "ds_kernel_utils.h" + +#include +#include + +#ifdef BF16_AVAILABLE +#include +#endif + +namespace conversion { + +// Basic primitive for constructing conversions +template +DS_D_INLINE TO to(FROM val) +{ + return to(val); +} + +// Specializations + +/********************* Identity Conversions *********************/ +/* +Identity conversions are useful in templated functions where we might have +a fixed destination type. For example, I might have a kernel that accepts +__half, __nv_bfloat16, and float but always want to do the core computation +at floating point: + +T mem_value = input[idx]; +float compute_value = conversion::to(mem_value); + +In practice, we should be able to elide the second template parameter: +float compute_val = conversion::to(mem_value); + +In this case, we need an implementation to handle the T = float case + +NOTE: The type inferencing system appears to be unable to handle inferring the first +template parameter, even in the trivial case. +*/ + +// Floating point types +template <> +DS_D_INLINE double to(double val) +{ + return val; +} +template <> +DS_D_INLINE float to(float val) +{ + return val; +} +template <> +DS_D_INLINE __half to(__half val) +{ + return val; +} +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) +{ + return val; +} +#endif + +// Integer types +template <> +DS_D_INLINE int8_t to(int8_t val) +{ + return val; +} +template <> +DS_D_INLINE uint8_t to(uint8_t val) +{ + return val; +} +template <> +DS_D_INLINE int16_t to(int16_t val) +{ + return val; +} +template <> +DS_D_INLINE uint16_t to(uint16_t val) +{ + return val; +} +template <> +DS_D_INLINE int32_t to(int32_t val) +{ + return val; +} +template <> +DS_D_INLINE uint32_t to(uint32_t val) +{ + return val; +} +template <> +DS_D_INLINE int64_t to(int64_t val) +{ + return val; +} +template <> +DS_D_INLINE uint64_t to(uint64_t val) +{ + return val; +} + +// TODO: evaluate if we want bools + +/********************* To Double Conversions *********************/ + +// * to double variants + +// Would normally like to not use C cast, but this is an important enough conversion +// to keep +template <> +DS_D_INLINE double to(float val) +{ +#ifdef PTX_AVAILABLE + double ret_val; + asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val)); + return ret_val; +#else + return double(val); +#endif +} +// Note: there is a CVT instruction for __half -> double, but there's no inline interface +// for passing a single half value +template <> +DS_D_INLINE double to(__half val) +{ + return to(__half2float(val)); +} +template <> +DS_D_INLINE double to(int64_t val) +{ + return __ll2double_rn(val); +} +template <> +DS_D_INLINE double to(int32_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int16_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(int8_t val) +{ + return __int2double_rn(val); +} +template <> +DS_D_INLINE double to(uint64_t val) +{ + return __ull2double_rn(val); +} +template <> +DS_D_INLINE double to(uint32_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint16_t val) +{ + return __uint2double_rn(val); +} +template <> +DS_D_INLINE double to(uint8_t val) +{ + return __uint2double_rn(val); +} + +// Same applies here +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE double to(__nv_bfloat16 val) +{ + return to(__bfloat162float(val)); +} +#endif + +/********************* To Float Conversions *********************/ + +template <> +DS_D_INLINE float to(double val) +{ + return __double2float_rn(val); +} +template <> +DS_D_INLINE float to(__half val) +{ + return __half2float(val); +} +template <> +DS_D_INLINE float to(int64_t val) +{ + return __ll2float_rn(val); +} +template <> +DS_D_INLINE float to(int32_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int16_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(int8_t val) +{ + return __int2float_rn(val); +} +template <> +DS_D_INLINE float to(uint64_t val) +{ + return __ull2float_rn(val); +} +template <> +DS_D_INLINE float to(uint32_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint16_t val) +{ + return __uint2float_rn(val); +} +template <> +DS_D_INLINE float to(uint8_t val) +{ + return __uint2float_rn(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float to(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} +#endif + +/********************* To Float2 Conversions *********************/ +template <> +DS_D_INLINE float2 to(__half2 val) +{ + return __half22float2(val); +} + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE float2 to(__nv_bfloat162 val) +{ + return __bfloat1622float2(val); +} +#endif + +/********************* To Half Conversions *********************/ +template <> +DS_D_INLINE __half to(double val) +{ +#ifdef __HIP_PLATFORM_HCC__ + float val_f = __double2float_rn(val); + return __float2half(val_f); +#else + return __double2half(val); +#endif +} +template <> +DS_D_INLINE __half to(float val) +{ + return __float2half(val); +} +template <> +DS_D_INLINE __half to(int64_t val) +{ + return __ll2half_rn(val); +} +template <> +DS_D_INLINE __half to(int32_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(int16_t val) +{ + return __short2half_rn(val); +} +template <> +DS_D_INLINE __half to(int8_t val) +{ + return __int2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint64_t val) +{ + return __ull2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint32_t val) +{ + return __uint2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint16_t val) +{ + return __ushort2half_rn(val); +} +template <> +DS_D_INLINE __half to(uint8_t val) +{ + return __uint2half_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half to(__nv_bfloat16 val) +{ + return to<__half>(to(val)); +} +#endif + +/********************* To Half2 Conversions *********************/ +template <> +DS_D_INLINE __half2 to(float2 val) +{ + return __float22half2_rn(val); +} +template <> +DS_D_INLINE __half2 to(float val) +{ + return __float2half2_rn(val); +} + +#ifdef BF16_AVAILABLE +// No direct conversion +template <> +DS_D_INLINE __half2 to(__nv_bfloat162 val) +{ + return to<__half2>(to(val)); +} +#endif + +/********************* To BF16 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat16 to(double val) +{ + return __double2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(float val) +{ + return __float2bfloat16(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int64_t val) +{ + return __ll2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int32_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int16_t val) +{ + return __short2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(int8_t val) +{ + return __int2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint64_t val) +{ + return __ull2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint32_t val) +{ + return __uint2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint16_t val) +{ + return __ushort2bfloat16_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat16 to(uint8_t val) +{ + return __uint2bfloat16_rn(val); +} +#endif + +/********************* To BF162 Conversions *********************/ +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE __nv_bfloat162 to(float2 val) +{ + return __float22bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(float val) +{ + return __float2bfloat162_rn(val); +} +template <> +DS_D_INLINE __nv_bfloat162 to(__half2 val) +{ + return to<__nv_bfloat162>(to(val)); +} +#endif + +/********************* To INT64_T Conversions *********************/ +template <> +DS_D_INLINE int64_t to(double val) +{ + return __double2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(float val) +{ + return __float2ll_rn(val); +} +template <> +DS_D_INLINE int64_t to(__half val) +{ + return __half2ll_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int64_t to(__nv_bfloat16 val) +{ + return __bfloat162ll_rn(val); +} +#endif + +/********************* To INT32_T Conversions *********************/ +template <> +DS_D_INLINE int32_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int32_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int32_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT16_T Conversions *********************/ +template <> +DS_D_INLINE int16_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int16_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int16_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To INT8_T Conversions *********************/ +template <> +DS_D_INLINE int8_t to(double val) +{ + return __double2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(float val) +{ + return __float2int_rn(val); +} +template <> +DS_D_INLINE int8_t to(__half val) +{ + return __half2int_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE int8_t to(__nv_bfloat16 val) +{ + return __bfloat162int_rn(val); +} +#endif + +/********************* To UINT64_T Conversions *********************/ +template <> +DS_D_INLINE uint64_t to(double val) +{ + return __double2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(float val) +{ + return __float2ull_rn(val); +} +template <> +DS_D_INLINE uint64_t to(__half val) +{ + return __half2ull_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint64_t to(__nv_bfloat16 val) +{ + return __bfloat162ull_rn(val); +} +#endif + +/********************* To UINT32_T Conversions *********************/ +template <> +DS_D_INLINE uint32_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint32_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint32_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT16_T Conversions *********************/ +template <> +DS_D_INLINE uint16_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint16_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint16_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +/********************* To UINT8_T Conversions *********************/ +template <> +DS_D_INLINE uint8_t to(double val) +{ + return __double2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(float val) +{ + return __float2uint_rn(val); +} +template <> +DS_D_INLINE uint8_t to(__half val) +{ + return __half2uint_rn(val); +} +// No direct support for integer casts at the C++ level and I don't feel they're so important +// to demand an PTX at this time + +#ifdef BF16_AVAILABLE +template <> +DS_D_INLINE uint8_t to(__nv_bfloat16 val) +{ + return __bfloat162uint_rn(val); +} +#endif + +} // namespace conversion \ No newline at end of file diff --git a/colossalai/gptq/csrc/includes/ds_kernel_utils.h b/colossalai/gptq/csrc/includes/ds_kernel_utils.h new file mode 100644 index 000000000000..99d8be17e503 --- /dev/null +++ b/colossalai/gptq/csrc/includes/ds_kernel_utils.h @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +/* +Centralized header file for preprocessor macros and constants +used throughout the codebase. +*/ + +#pragma once + +#include + +#define DS_HD_INLINE __host__ __device__ __forceinline__ +#define DS_D_INLINE __device__ __forceinline__ + +#ifdef __HIP_PLATFORM_HCC__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = 64; +#define HALF_PRECISION_AVAILABLE = 1 +#include + +#else // !__HIP_PLATFORM_HCC__ + +// constexpr variant of warpSize for templating +constexpr int hw_warp_size = 32; + +#if __CUDA_ARCH__ >= 530 +#define HALF_PRECISION_AVAILABLE = 1 +#define PTX_AVAILABLE +#endif // __CUDA_ARCH__ >= 530 + +#if __CUDA_ARCH__ >= 800 +#define ASYNC_COPY_AVAILABLE +#define BF16_AVAILABLE +#endif // __CUDA_ARCH__ >= 800 + +#include + +#endif //__HIP_PLATFORM_HCC__ + +inline int next_pow2(const int val) +{ + int rounded_val = val - 1; + rounded_val |= rounded_val >> 1; + rounded_val |= rounded_val >> 2; + rounded_val |= rounded_val >> 4; + rounded_val |= rounded_val >> 8; + return rounded_val + 1; +} \ No newline at end of file diff --git a/colossalai/gptq/csrc/includes/inference_cuda_layers.h b/colossalai/gptq/csrc/includes/inference_cuda_layers.h new file mode 100644 index 000000000000..08ea00c6558c --- /dev/null +++ b/colossalai/gptq/csrc/includes/inference_cuda_layers.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#ifdef BF16_AVAILABLE +#include +#endif +#include +#include +#include +#include +#include +#include + +template +at::Tensor gptq_act_linear_layer(at::Tensor& input, + at::Tensor& weight, + at::Tensor& weight_scales, + at::Tensor& weight_zeros, + at::Tensor& bias, + at::Tensor& residual, + int64_t group_size, + int32_t act_type, + int32_t add_bias, + int32_t add_residual, + int32_t qkv_fused, + uint64_t block_size_x, + uint64_t block_size_y); \ No newline at end of file diff --git a/colossalai/gptq/csrc/pt_binding.cpp b/colossalai/gptq/csrc/pt_binding.cpp new file mode 100644 index 000000000000..80e96fde6673 --- /dev/null +++ b/colossalai/gptq/csrc/pt_binding.cpp @@ -0,0 +1,23 @@ +#include "inference_cuda_layers.h" +#include +#include +#include +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + + m.def("gptq_act_linear_fp16", + &gptq_act_linear_layer<__half, uint64_t>, + "gptq linear kernel (CUDA)"); + + + m.def("gptq_act_linear_fp16_w32", + &gptq_act_linear_layer<__half, uint32_t>, + "gptq linear kernel (CUDA)"); + + m.def("gptq_act_linear_fp16_w8", + &gptq_act_linear_layer<__half, uint8_t>, + "gptq linear kernel (CUDA)"); + +} diff --git a/colossalai/gptq/gptq_utils/__init__.py b/colossalai/gptq/gptq_utils/__init__.py new file mode 100644 index 000000000000..2b9db6637df3 --- /dev/null +++ b/colossalai/gptq/gptq_utils/__init__.py @@ -0,0 +1 @@ +from .gptq import GPTQ, Observer \ No newline at end of file diff --git a/colossalai/gptq/gptq_utils/gptq.py b/colossalai/gptq/gptq_utils/gptq.py new file mode 100644 index 000000000000..e17c0c47c6d7 --- /dev/null +++ b/colossalai/gptq/gptq_utils/gptq.py @@ -0,0 +1,236 @@ +import math +import time + +import torch +import torch.nn as nn +import transformers +from .quant import Quantizer +from texttable import Texttable +from .utils import torch_snr_error + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class Observer: + + def __init__(self, topk=32): + self.loss_list = [] + self.topk = topk + + def submit(self, name: str, layerid: int, gptq, error: float): + + item = (name, layerid, {'gptq': gptq, 'error': error}) + + if len(self.loss_list) < self.topk: + self.loss_list.append(item) + return + + min_error = error + min_idx = -1 + for idx, data in enumerate(self.loss_list): + if min_error > data[2]['error']: + min_idx = idx + min_error = data[2]['error'] + + if min_idx >= 0: + self.loss_list[min_idx] = item + + def print(self): + self.loss_list = sorted(self.loss_list, key=lambda s: s[2]['error'], reverse=True) + + table = Texttable() + + table.header(['name', 'error']) + table.set_cols_dtype(['t', 'f']) + + for item in self.loss_list: + table.add_row([f"{item[0]}.{item[1]}", item[2]['error']]) + print(table.draw()) + print('\n') + + def items(self): + return self.loss_list + + +class GPTQ: + + def __init__(self, layer, observe=False): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.quantizer = Quantizer() + self.observe = observe + + def add_batch(self, inp, out): + # Hessian H = 2 X XT + λ I + if self.observe: + self.inp1 = inp + self.out1 = out + else: + self.inp1 = None + self.out1 = None + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def print_loss(self, name, q_weight, weight_error, timecost): + table = Texttable() + name += ' ' * (16 - len(name)) + + table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time']) + + # assign weight + self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + + if self.inp1 is not None: + # quantize input to int8 + quantizer = Quantizer() + quantizer.configure(8, perchannel=False, sym=True, mse=False) + quantizer.find_params(self.inp1) + q_in = quantizer.quantize(self.inp1).type(torch.float16) + q_out = self.layer(q_in) + + # get kinds of SNR + q_SNR = torch_snr_error(q_out, self.out1).item() + fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() + else: + q_SNR = '-' + fp_SNR = '-' + + table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) + print(table.draw().split('\n')[-2]) + + def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''): + self.layer.to(self.dev) + + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + if not self.observe: + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = self.quantizer.quantize(w.unsqueeze(1)).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q)**2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + error = torch.sum(Losses).item() + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + + self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + return scale, zero, g_idx, error + + def free(self): + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() diff --git a/colossalai/gptq/gptq_utils/quant/__init__.py b/colossalai/gptq/gptq_utils/quant/__init__.py new file mode 100644 index 000000000000..64452784656b --- /dev/null +++ b/colossalai/gptq/gptq_utils/quant/__init__.py @@ -0,0 +1,5 @@ +from .quantizer import Quantizer +from .fused_attn import QuantLlamaAttention, make_quant_attn +from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused +from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear +from .triton_norm import TritonLlamaRMSNorm, make_quant_norm diff --git a/colossalai/gptq/gptq_utils/quant/custom_autotune.py b/colossalai/gptq/gptq_utils/quant/custom_autotune.py new file mode 100644 index 000000000000..286cf5d08586 --- /dev/null +++ b/colossalai/gptq/gptq_utils/quant/custom_autotune.py @@ -0,0 +1,194 @@ +#https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + +import builtins +import math +import time +from typing import Dict + +import triton + + +class Autotuner(triton.KernelInterface): + + def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. + 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + ''' + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + return triton.testing.do_bench(kernel_call, rep=40) + + # try: + # # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + # return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) + # except triton.compiler.OutOfResources: + # return (float('inf'), float('inf'), float('inf')) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2**int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. + """ + m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) + n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) + k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) + block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) + block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) + group_size_m = config.kwargs['GROUP_SIZE_M'] + + if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: + continue + + used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) + yield triton.Config({ + 'BLOCK_SIZE_M': block_size_m, + 'BLOCK_SIZE_N': block_size_n, + 'BLOCK_SIZE_K': block_size_k, + 'GROUP_SIZE_M': group_size_m + }, + num_stages=config.num_stages, + num_warps=config.num_warps) diff --git a/colossalai/gptq/gptq_utils/quant/fused_attn.py b/colossalai/gptq/gptq_utils/quant/fused_attn.py new file mode 100644 index 000000000000..2076c2cc37b2 --- /dev/null +++ b/colossalai/gptq/gptq_utils/quant/fused_attn.py @@ -0,0 +1,204 @@ +from torch.nn import functional as F +from transformers.models.llama.modeling_llama import LlamaAttention +from .quant_linear import * +import triton +import triton.language as tl + + +@triton.jit +def rotate_half_kernel( + qk_seq_ptr, + position_ids_ptr, + qk_seq_stride, + position_ids_batch_stride, + seq_len, + HEAD_DIM: tl.constexpr, + BLOCK_HEIGHT: tl.constexpr, + BLOCK_WIDTH: tl.constexpr, + INV_BASE: tl.constexpr +): + # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension. + # position ids: (bsz, seq_len) -- must be contiguous in the last dimension. + + HALF_HEAD: tl.constexpr = HEAD_DIM // 2 + STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH + + batch_seq = tl.program_id(axis=0) + row_blk_x_col_blk = tl.program_id(axis=1) + + row_blk = row_blk_x_col_blk // STEPS_PER_ROW + row = row_blk * BLOCK_HEIGHT + if BLOCK_WIDTH < HALF_HEAD: + col_blk = row_blk_x_col_blk % STEPS_PER_ROW + col = col_blk * BLOCK_WIDTH + else: + col: tl.constexpr = 0 + + # A block will never cross a sequence boundary, which simplifies things a lot. + batch = batch_seq // seq_len + seq = batch_seq % seq_len + position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq) + # As sometimes happens, just calculating this on the fly is faster than loading it from memory. + # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate. + freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id + cos = tl.cos(freq).to(tl.float32) + sin = tl.sin(freq).to(tl.float32) + + col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH) + embed_offsets = (row * HEAD_DIM + col) + col_offsets + x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets + + for k in range(0, BLOCK_HEIGHT): + x = tl.load(x_ptrs).to(tl.float32) + y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32) + out_x = x * cos - y * sin + tl.store(x_ptrs, out_x) + out_y = x * sin + y * cos + tl.store(x_ptrs + HALF_HEAD, out_y) + x_ptrs += HEAD_DIM + + +def triton_rotate_half_(qk, position_ids, config=None): + with torch.cuda.device(qk.device): + batch_size, seq_len, qandk, num_heads, head_dim = qk.shape + + # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective. + config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1} + config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads) + + assert qk.stride(3) == head_dim + assert qk.stride(4) == 1 + assert position_ids.shape == (batch_size, seq_len) + assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension' + assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config["BLOCK_HEIGHT"]}' + assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config["BLOCK_WIDTH"]}' + + qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim) + grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH'])) + + # Must be the same as the theta of the frequencies used to train the model. + BASE = 10000.0 + + rotate_half_kernel[grid]( + qk_by_seq, + position_ids, + qk_by_seq.stride(0), + position_ids.stride(0), + seq_len, + HEAD_DIM=head_dim, + BLOCK_HEIGHT=config['BLOCK_HEIGHT'], + BLOCK_WIDTH=config['BLOCK_WIDTH'], + INV_BASE=-2.0 * math.log(BASE) / head_dim, + num_warps=config['num_warps'] + ) + + +class QuantLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size, + num_heads, + qkv_proj, + o_proj + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads}).") + self.qkv_proj = qkv_proj + self.o_proj = o_proj + + def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False): + """Input shape: Batch x Time x Channel""" + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim) + + # This updates the query and key states in-place, saving VRAM. + triton_rotate_half_(qkv_states[:, :, :2], position_ids) + + query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2) + del qkv_states + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + is_causal = past_key_value is None + + kv_seq_len = q_len + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + if use_cache: + # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor + # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. + key_states = key_states.contiguous() + value_states = value_states.contiguous() + query_states = query_states.contiguous() + + past_key_value = (key_states, value_states) if use_cache else None + + with torch.backends.cuda.sdp_kernel(enable_math=False): + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal) + del query_states, key_states, value_states + + attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def make_quant_attn(model): + """ + Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. + """ + + for name, m in model.named_modules(): + if not isinstance(m, LlamaAttention): + continue + + q_proj = m.q_proj + k_proj = m.k_proj + v_proj = m.v_proj + + qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) + qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) + scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) + bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + + qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False) + qkv_layer.qweight = qweights + qkv_layer.qzeros = qzeros + qkv_layer.scales = scales + qkv_layer.g_idx = g_idx + qkv_layer.bias = bias + # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch. + + attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj) + + if '.' in name: + parent_name = name.rsplit('.', 1)[0] + child_name = name[len(parent_name) + 1:] + parent = model.get_submodule(parent_name) + else: + parent_name = '' + parent = model + child_name = name + + #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") + + setattr(parent, child_name, attn) diff --git a/colossalai/gptq/gptq_utils/quant/fused_mlp.py b/colossalai/gptq/gptq_utils/quant/fused_mlp.py new file mode 100644 index 000000000000..a5e402e38f94 --- /dev/null +++ b/colossalai/gptq/gptq_utils/quant/fused_mlp.py @@ -0,0 +1,288 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd +from transformers.models.llama.modeling_llama import LlamaMLP + +try: + import triton + import triton.language as tl + from . import custom_autotune + + # code based https://github.com/fpgaminer/GPTQ-triton + @custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 256, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 16, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 16, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, + ) + @triton.jit + def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Computes: C = silu(A * B1) * (A * B2) + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (1, N) float16 + zeros is of shape (1, N//8) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + g1_ptrs = g1_ptr + offs_k + g2_ptrs = g2_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales1_ptrs = scales1_ptr + offs_bn[None, :] + scales2_ptrs = scales2_ptr + offs_bn[None, :] + zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits) + zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, num_pid_k): + g1_idx = tl.load(g1_ptrs) + g2_idx = tl.load(g2_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales) + + zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq + zeros1 = (zeros1 + 1) + + zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq + zeros2 = (zeros2 + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + b2 = tl.load(b2_ptrs) + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values + b1 = (b1 - zeros1) * scales1 # Scale and shift + accumulator1 += tl.dot(a, b1) + + b2 = (b2 >> shifter[:, None]) & maxq + b2 = (b2 - zeros2) * scales2 + accumulator2 += tl.dot(a, b2) + + a_ptrs += BLOCK_SIZE_K + b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g1_ptrs += BLOCK_SIZE_K + g2_ptrs += BLOCK_SIZE_K + + accumulator1 = silu(accumulator1) + c = accumulator1 * accumulator2 + c = c.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + @triton.jit + def silu(x): + return x * tl.sigmoid(x) +except: + print('triton not installed.') + + +class QuantLlamaMLP(nn.Module): + + def __init__( + self, + gate_proj, + down_proj, + up_proj, + ): + super().__init__() + self.register_buffer('gate_proj_qweight', gate_proj.qweight) + self.register_buffer('gate_proj_scales', gate_proj.scales) + self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) + self.register_buffer('gate_proj_g_idx', gate_proj.g_idx) + self.register_buffer('up_proj_qweight', up_proj.qweight) + self.register_buffer('up_proj_scales', up_proj.scales) + self.register_buffer('up_proj_qzeros', up_proj.qzeros) + self.register_buffer('up_proj_g_idx', up_proj.g_idx) + + self.infeatures = gate_proj.infeatures + self.intermediate_size = gate_proj.outfeatures + self.outfeatures = down_proj.outfeatures + self.bits = gate_proj.bits + self.maxq = gate_proj.maxq + + self.down_proj = down_proj + + def forward(self, x): + return self.down_proj(self.triton_llama_mlp(x)) + + def triton_llama_mlp(self, x): + with torch.cuda.device(x.device): + out_shape = x.shape[:-1] + (self.intermediate_size, ) + x = x.reshape(-1, x.shape[-1]) + M, K = x.shape + N = self.intermediate_size + c = torch.empty((M, N), device=x.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales, + self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0), + self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0)) + c = c.reshape(out_shape) + return c + + def fused2cuda(self): + self.gate_proj_qweight = self.gate_proj_qweight.cuda() + self.gate_proj_scales = self.gate_proj_scales.cuda() + self.gate_proj_qzeros = self.gate_proj_qzeros.cuda() + self.gate_proj_g_idx = self.gate_proj_g_idx.cuda() + self.up_proj_qweight = self.up_proj_qweight.cuda() + self.up_proj_scales = self.up_proj_scales.cuda() + self.up_proj_qzeros = self.up_proj_qzeros.cuda() + self.up_proj_g_idx = self.up_proj_g_idx.cuda() + + def fused2cpu(self): + self.gate_proj_qweight = self.gate_proj_qweight.cpu() + self.gate_proj_scales = self.gate_proj_scales.cpu() + self.gate_proj_qzeros = self.gate_proj_qzeros.cpu() + self.gate_proj_g_idx = self.gate_proj_g_idx.cpu() + self.up_proj_qweight = self.up_proj_qweight.cpu() + self.up_proj_scales = self.up_proj_scales.cpu() + self.up_proj_qzeros = self.up_proj_qzeros.cpu() + self.up_proj_g_idx = self.up_proj_g_idx.cpu() + + +def make_fused_mlp(m, parent_name=''): + """ + Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. + """ + if isinstance(m, LlamaMLP): + return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj) + + for name, child in m.named_children(): + child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") + + if isinstance(child, QuantLlamaMLP): + setattr(m, name, child) + return m + + +def autotune_warmup_fused(model): + """ + Pre-tunes the quantized kernel + """ + from tqdm import tqdm + + kn_values = {} + + for _, m in model.named_modules(): + if not isinstance(m, QuantLlamaMLP): + continue + + k = m.infeatures + n = m.intermediate_size + + m.fused2cuda() + if (k, n) not in kn_values: + kn_values[(k, n)] = m + + print(f'Found {len(kn_values)} unique fused mlp KN values.') + + print('Warming up autotune cache ...') + with torch.no_grad(): + for m in tqdm(range(0, 12)): + m = 2**m # [1, 2048] + for (k, n), (modules) in kn_values.items(): + a = torch.randn(m, k, dtype=torch.float16, device='cuda') + modules.triton_llama_mlp(a) + + for (k, n), (modules) in kn_values.items(): + a = torch.randn(m, k, dtype=torch.float16, device='cuda') + modules.fused2cpu() + del kn_values diff --git a/colossalai/gptq/gptq_utils/quant/quant_linear.py b/colossalai/gptq/gptq_utils/quant/quant_linear.py new file mode 100644 index 000000000000..5144a962a928 --- /dev/null +++ b/colossalai/gptq/gptq_utils/quant/quant_linear.py @@ -0,0 +1,422 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import triton + import triton.language as tl + from . import custom_autotune + + # code based https://github.com/fpgaminer/GPTQ-triton + @custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, + ) + @triton.jit + def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @custom_autotune.autotune(configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 256, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True) + @triton.jit + def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, + stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, N) float16 + B is of shape (K//8, N) int32 + C is of shape (M, K) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_k + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_k = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = tl.arange(0, BLOCK_SIZE_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_bk + g_idx = tl.load(g_ptrs) + + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros + + shifter = (offs_bk % infearure_per_bits) * bits + zeros_shifter = (offs_n % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + for n in range(0, num_pid_n): + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + b = tl.trans(b) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_N + b_ptrs += BLOCK_SIZE_N + scales_ptrs += BLOCK_SIZE_N + zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) + tl.store(c_ptrs, accumulator, mask=c_mask) +except: + print('triton not installed.') + + +def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) + matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + return output + + +def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output_dim = (qweight.shape[0] * 32) // bits + output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) + transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + return output + + +class QuantLinearFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) + ctx.save_for_backward(qweight, scales, qzeros, g_idx) + ctx.bits, ctx.maxq = bits, maxq + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + qweight, scales, qzeros, g_idx = ctx.saved_tensors + bits, maxq = ctx.bits, ctx.maxq + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) + return grad_input, None, None, None, None, None, None + + +class QuantLinear(nn.Module): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else infeatures + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) + +def make_quant_linear(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + delattr(module, attr) + setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + for name1, child in module.named_children(): + make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + + +def autotune_warmup_linear(model, transpose=False): + """ + Pre-tunes the quantized kernel + """ + from tqdm import tqdm + + kn_values = {} + + for _, m in model.named_modules(): + if not isinstance(m, QuantLinear): + continue + + k = m.infeatures + n = m.outfeatures + + if (k, n) not in kn_values: + kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq) + + print(f'Found {len(kn_values)} unique KN Linear values.') + + print('Warming up autotune cache ...') + with torch.no_grad(): + for m in tqdm(range(0, 12)): + m = 2**m # [1, 2048] + for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items(): + a = torch.randn(m, k, dtype=torch.float16, device='cuda') + matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) + if transpose: + a = torch.randn(m, n, dtype=torch.float16, device='cuda') + transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) + del kn_values diff --git a/colossalai/gptq/gptq_utils/quant/quantizer.py b/colossalai/gptq/gptq_utils/quant/quantizer.py new file mode 100644 index 000000000000..76844b8769aa --- /dev/null +++ b/colossalai/gptq/gptq_utils/quant/quantizer.py @@ -0,0 +1,127 @@ +import numpy as np +import torch +import torch.nn as nn +import math + + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False): + + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + self.scale = torch.zeros_like(self.scale) + + def _quantize(self, x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return self._quantize(x, self.scale, self.zero, self.maxq) + + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) diff --git a/colossalai/gptq/gptq_utils/quant/triton_norm.py b/colossalai/gptq/gptq_utils/quant/triton_norm.py new file mode 100644 index 000000000000..1e3228a18d51 --- /dev/null +++ b/colossalai/gptq/gptq_utils/quant/triton_norm.py @@ -0,0 +1,92 @@ +import torch +from torch import nn +import triton +import triton.language as tl +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +@triton.jit +def rms_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x, 0.) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = x * rstd + y = x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + +class TritonLlamaRMSNorm(nn.Module): + def __init__(self, weight, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = weight + self.variance_epsilon = eps + + def forward(self, x): + with torch.cuda.device(x.device): + y = torch.empty_like(x) + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # enqueue kernel + rms_norm_fwd_fused[(M,)](x_arg, y, self.weight, + x_arg.stride(0), N, self.variance_epsilon, + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + return y + + +def make_quant_norm(model): + """ + Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules + """ + + for name, m in model.named_modules(): + if not isinstance(m, LlamaRMSNorm): + continue + + norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon) + + if '.' in name: + parent_name = name.rsplit('.', 1)[0] + child_name = name[len(parent_name) + 1:] + parent = model.get_submodule(parent_name) + else: + parent_name = '' + parent = model + child_name = name + + #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") + + setattr(parent, child_name, norm) diff --git a/colossalai/gptq/gptq_utils/utils/__init__.py b/colossalai/gptq/gptq_utils/utils/__init__.py new file mode 100644 index 000000000000..cf1741216f79 --- /dev/null +++ b/colossalai/gptq/gptq_utils/utils/__init__.py @@ -0,0 +1,3 @@ +from .modelutils import DEV, find_layers, gen_conditions, torch_snr_error +from .datautils import set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders +from .export import export_quant_table diff --git a/colossalai/gptq/gptq_utils/utils/datautils.py b/colossalai/gptq/gptq_utils/utils/datautils.py new file mode 100644 index 000000000000..10a3a43d3ef5 --- /dev/null +++ b/colossalai/gptq/gptq_utils/utils/datautils.py @@ -0,0 +1,193 @@ +import numpy as np +import torch + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_wikitext2(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') + testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') + + from transformers import AutoTokenizer + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') + testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') + + from transformers import AutoTokenizer + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False) + valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False) + + from transformers import AutoTokenizer + try: + if "llama" in model: + from transformers import LlamaTokenizer + tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]['text'], return_tensors='pt') + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_ptb_new(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') + testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') + + from transformers import AutoTokenizer + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') + testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4_new(nsamples, seed, seqlen, model): + from datasets import load_dataset + traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') + valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') + + from transformers import AutoTokenizer + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + + import random + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') + valenc = valenc.input_ids[:, :(256 * seqlen)] + + class TokenizerWrapper: + + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): + if 'wikitext2' in name: + return get_wikitext2(nsamples, seed, seqlen, model) + if 'ptb' in name: + if 'new' in name: + return get_ptb_new(nsamples, seed, seqlen, model) + return get_ptb(nsamples, seed, seqlen, model) + if 'c4' in name: + if 'new' in name: + return get_c4_new(nsamples, seed, seqlen, model) + return get_c4(nsamples, seed, seqlen, model) diff --git a/colossalai/gptq/gptq_utils/utils/export.py b/colossalai/gptq/gptq_utils/utils/export.py new file mode 100644 index 000000000000..a623afcf49b5 --- /dev/null +++ b/colossalai/gptq/gptq_utils/utils/export.py @@ -0,0 +1,37 @@ +import numpy as np +import toml +import os + + +def export_quant_table(quantizers: dict, quant_dir: str, format: str = 'toml'): + + table = {} + + def save_tensor(name: str, tensor): + np.save(os.path.join(quant_dir, name), tensor.numpy()) + return '{}.npy'.format(name) + + for key, value in quantizers.items(): + quantizer = value[0] + + dump = dict() + + sym = quantizer.sym + if not sym: + dump['zero'] = save_tensor(name=key + '.zero', tensor=value[2]) + dump['scale'] = save_tensor(name=key + '.scale', tensor=value[1]) + dump['wbits'] = value[4] + dump['groupsize'] = value[5] + if value[5] > 0: + dump['group_ids'] = save_tensor(name=key + '.group_ids', tensor=value[3]) + + dump['sym'] = sym + dump['perchannel'] = quantizer.perchannel + + table[key] = dump + + if not os.path.exists(quant_dir): + os.mkdir(quant_dir) + + with open(os.path.join(quant_dir, 'quant.toml'), 'w') as f: + toml.dump(table, f) diff --git a/colossalai/gptq/gptq_utils/utils/modelutils.py b/colossalai/gptq/gptq_utils/utils/modelutils.py new file mode 100644 index 000000000000..d043cca02b7d --- /dev/null +++ b/colossalai/gptq/gptq_utils/utils/modelutils.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn + +DEV = torch.device('cuda:0') + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) + return res + + +def gen_conditions(_wbits, _groupsize): + wbits = _wbits + groupsize = _groupsize + conditions = [] + while True: + if wbits >= 8: + if groupsize == -1 or groupsize == 32: + break + + if groupsize > 32: + groupsize /= 2 + else: + wbits *= 2 + groupsize = _groupsize + + conditions.append((int(wbits), int(groupsize))) + return conditions + + +# copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py +def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: + """ + Compute SNR between y_pred(tensor) and y_real(tensor) + + SNR can be calcualted as following equation: + + SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 + + if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. + + SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) + Args: + y_pred (torch.Tensor): _description_ + y_real (torch.Tensor): _description_ + reduction (str, optional): _description_. Defaults to 'mean'. + Raises: + ValueError: _description_ + ValueError: _description_ + Returns: + torch.Tensor: _description_ + """ + y_pred = y_pred.type(torch.float32) + y_real = y_real.type(torch.float32) + + if y_pred.shape != y_real.shape: + raise ValueError(f'Can not compute snr loss for tensors with different shape. ' + f'({y_pred.shape} and {y_real.shape})') + reduction = str(reduction).lower() + + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(0) + y_real = y_real.unsqueeze(0) + + y_pred = y_pred.flatten(start_dim=1) + y_real = y_real.flatten(start_dim=1) + + noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) + signal_power = torch.pow(y_real, 2).sum(dim=-1) + snr = (noise_power) / (signal_power + 1e-7) + + if reduction == 'mean': + return torch.mean(snr) + elif reduction == 'sum': + return torch.sum(snr) + elif reduction == 'none': + return snr + else: + raise ValueError(f'Unsupported reduction method.') diff --git a/colossalai/gptq/inference_builder.py b/colossalai/gptq/inference_builder.py new file mode 100644 index 000000000000..60cb7167a76e --- /dev/null +++ b/colossalai/gptq/inference_builder.py @@ -0,0 +1,761 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import sys +import time +import importlib +from pathlib import Path +import subprocess +import shlex +import shutil +import tempfile +import distutils.ccompiler +import distutils.log +import distutils.sysconfig +from distutils.errors import CompileError, LinkError +from abc import ABC, abstractmethod +from typing import List + +YELLOW = '\033[93m' +END = '\033[0m' +WARNING = f"{YELLOW} [WARNING] {END}" + +DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions" +DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0" + +try: + import torch +except ImportError: + print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops.") +else: + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + +def installed_cuda_version(name=""): + import torch.utils.cpp_extension + cuda_home = torch.utils.cpp_extension.CUDA_HOME + assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" + # Ensure there is not a cuda version mismatch between torch and nvcc compiler + output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) + output_split = output.split() + release_idx = output_split.index("release") + release = output_split[release_idx + 1].replace(',', '').split(".") + # Ignore patch versions, only look at major + minor + cuda_major, cuda_minor = release[:2] + return int(cuda_major), int(cuda_minor) + + +def get_default_compute_capabilities(): + compute_caps = DEFAULT_COMPUTE_CAPABILITIES + import torch.utils.cpp_extension + if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11: + if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0: + # Special treatment of CUDA 11.0 because compute_86 is not supported. + compute_caps += ";8.0" + else: + compute_caps += ";8.0;8.6" + return compute_caps + + +# list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used +# to build deepspeed and system-wide installed cuda 11.2 +cuda_minor_mismatch_ok = { + 10: [ + "10.0", + "10.1", + "10.2", + ], + 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], +} + + +def assert_no_cuda_mismatch(name=""): + cuda_major, cuda_minor = installed_cuda_version(name) + sys_cuda_version = f'{cuda_major}.{cuda_minor}' + torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + # This is a show-stopping error, should probably not proceed past this + if sys_cuda_version != torch_cuda_version: + if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major] + and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]): + print(f"Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda} " + "but since the APIs are compatible, accepting this combination") + return True + raise Exception(f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " + f"version torch was compiled with {torch.version.cuda}, unable to compile " + "cuda/cpp extensions without a matching cuda version.") + return True + + +class OpBuilder(ABC): + _rocm_version = None + _is_rocm_pytorch = None + + def __init__(self, name): + self.name = name + self.jit_mode = False + self.build_for_cpu = False + self.error_log = None + + @abstractmethod + def absolute_name(self): + ''' + Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam + will be installed as something like: deepspeed/ops/adam/cpu_adam.so + ''' + pass + + @abstractmethod + def sources(self): + ''' + Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + pass + + def hipify_extension(self): + pass + + @staticmethod + def validate_torch_version(torch_info): + install_torch_version = torch_info['version'] + current_torch_version = ".".join(torch.__version__.split('.')[:2]) + if install_torch_version != current_torch_version: + raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. " + f"Install torch version={install_torch_version}, " + f"Runtime torch version={current_torch_version}") + + @staticmethod + def validate_torch_op_version(torch_info): + if not OpBuilder.is_rocm_pytorch(): + current_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + install_cuda_version = torch_info['cuda_version'] + if install_cuda_version != current_cuda_version: + raise RuntimeError("CUDA version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. " + f"Install CUDA version={install_cuda_version}, " + f"Runtime CUDA version={current_cuda_version}") + else: + current_hip_version = ".".join(torch.version.hip.split('.')[:2]) + install_hip_version = torch_info['hip_version'] + if install_hip_version != current_hip_version: + raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed " + "with a different version than what is being used at runtime. " + f"Please re-install DeepSpeed or switch torch versions. " + f"Install HIP version={install_hip_version}, " + f"Runtime HIP version={current_hip_version}") + + @staticmethod + def is_rocm_pytorch(): + if OpBuilder._is_rocm_pytorch is not None: + return OpBuilder._is_rocm_pytorch + + _is_rocm_pytorch = False + try: + import torch + except ImportError: + pass + else: + if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): + _is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None + if _is_rocm_pytorch: + from torch.utils.cpp_extension import ROCM_HOME + _is_rocm_pytorch = ROCM_HOME is not None + OpBuilder._is_rocm_pytorch = _is_rocm_pytorch + return OpBuilder._is_rocm_pytorch + + @staticmethod + def installed_rocm_version(): + if OpBuilder._rocm_version: + return OpBuilder._rocm_version + + ROCM_MAJOR = '0' + ROCM_MINOR = '0' + if OpBuilder.is_rocm_pytorch(): + from torch.utils.cpp_extension import ROCM_HOME + rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev") + if rocm_ver_file.is_file(): + with open(rocm_ver_file, 'r') as file: + ROCM_VERSION_DEV_RAW = file.read() + elif "rocm" in torch.__version__: + ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1] + else: + assert False, "Could not detect ROCm version" + assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version" + ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] + ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] + OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) + return OpBuilder._rocm_version + + def include_paths(self): + ''' + Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) + ''' + return [] + + def nvcc_args(self): + ''' + Returns optional list of compiler flags to forward to nvcc when building CUDA sources + ''' + return [] + + def cxx_args(self): + ''' + Returns optional list of compiler flags to forward to the build + ''' + return [] + + def is_compatible(self, verbose=True): + ''' + Check if all non-python dependencies are satisfied to build this op + ''' + return True + + def extra_ldflags(self): + return [] + + def libraries_installed(self, libraries): + valid = False + check_cmd = 'dpkg -l' + for lib in libraries: + result = subprocess.Popen(f'dpkg -l {lib}', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + valid = valid or result.wait() == 0 + return valid + + def has_function(self, funcname, libraries, verbose=False): + ''' + Test for existence of a function within a tuple of libraries. + + This is used as a smoke test to check whether a certain library is available. + As a test, this creates a simple C program that calls the specified function, + and then distutils is used to compile that program and link it with the specified libraries. + Returns True if both the compile and link are successful, False otherwise. + ''' + tempdir = None # we create a temporary directory to hold various files + filestderr = None # handle to open file to which we redirect stderr + oldstderr = None # file descriptor for stderr + try: + # Echo compile and link commands that are used. + if verbose: + distutils.log.set_verbosity(1) + + # Create a compiler object. + compiler = distutils.ccompiler.new_compiler(verbose=verbose) + + # Configure compiler and linker to build according to Python install. + distutils.sysconfig.customize_compiler(compiler) + + # Create a temporary directory to hold test files. + tempdir = tempfile.mkdtemp() + + # Define a simple C program that calls the function in question + prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname) + + # Write the test program to a file. + filename = os.path.join(tempdir, 'test.c') + with open(filename, 'w') as f: + f.write(prog) + + # Redirect stderr file descriptor to a file to silence compile/link warnings. + if not verbose: + filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w') + oldstderr = os.dup(sys.stderr.fileno()) + os.dup2(filestderr.fileno(), sys.stderr.fileno()) + + # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames() + # Otherwise, a local directory will be used instead of tempdir + drive, driveless_filename = os.path.splitdrive(filename) + root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else '' + output_dir = os.path.join(drive, root_dir) + + # Attempt to compile the C program into an object file. + cflags = shlex.split(os.environ.get('CFLAGS', "")) + objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags)) + + # Attempt to link the object file into an executable. + # Be sure to tack on any libraries that have been specified. + ldflags = shlex.split(os.environ.get('LDFLAGS', "")) + compiler.link_executable(objs, + os.path.join(tempdir, 'a.out'), + extra_preargs=self.strip_empty_entries(ldflags), + libraries=libraries) + + # Compile and link succeeded + return True + + except CompileError: + return False + + except LinkError: + return False + + except: + return False + + finally: + # Restore stderr file descriptor and close the stderr redirect file. + if oldstderr is not None: + os.dup2(oldstderr, sys.stderr.fileno()) + if filestderr is not None: + filestderr.close() + + # Delete the temporary directory holding the test program and stderr files. + if tempdir is not None: + shutil.rmtree(tempdir) + + def strip_empty_entries(self, args): + ''' + Drop any empty strings from the list of compile and link flags + ''' + return [x for x in args if len(x) > 0] + + def cpu_arch(self): + try: + from cpuinfo import get_cpu_info + except ImportError as e: + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return "-march=native" + + try: + cpu_info = get_cpu_info() + except Exception as e: + self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " + "falling back to `lscpu` to get this information.") + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return "-march=native" + + if cpu_info['arch'].startswith('PPC_'): + # gcc does not provide -march on PowerPC, use -mcpu instead + return '-mcpu=native' + return '-march=native' + + def is_cuda_enable(self): + try: + assert_no_cuda_mismatch(self.name) + return '-D__ENABLE_CUDA__' + except BaseException: + print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " + "only cpu ops can be compiled!") + return '-D__DISABLE_CUDA__' + return '-D__DISABLE_CUDA__' + + def _backup_cpuinfo(self): + # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides + if not self.command_exists('lscpu'): + self.warning(f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo " + "to detect the CPU architecture. 'lscpu' does not appear to exist on " + "your system, will fall back to use -march=native and non-vectorized execution.") + return None + result = subprocess.check_output('lscpu', shell=True) + result = result.decode('utf-8').strip().lower() + + cpu_info = {} + cpu_info['arch'] = None + cpu_info['flags'] = "" + if 'genuineintel' in result or 'authenticamd' in result: + cpu_info['arch'] = 'X86_64' + if 'avx512' in result: + cpu_info['flags'] += 'avx512,' + elif 'avx512f' in result: + cpu_info['flags'] += 'avx512f,' + if 'avx2' in result: + cpu_info['flags'] += 'avx2' + elif 'ppc64le' in result: + cpu_info['arch'] = "PPC_" + + return cpu_info + + def simd_width(self): + try: + from cpuinfo import get_cpu_info + except ImportError as e: + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return '-D__SCALAR__' + + try: + cpu_info = get_cpu_info() + except Exception as e: + self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " + "falling back to `lscpu` to get this information.") + cpu_info = self._backup_cpuinfo() + if cpu_info is None: + return '-D__SCALAR__' + + if cpu_info['arch'] == 'X86_64': + if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']: + return '-D__AVX512__' + elif 'avx2' in cpu_info['flags']: + return '-D__AVX256__' + return '-D__SCALAR__' + + def command_exists(self, cmd): + if '|' in cmd: + cmds = cmd.split("|") + else: + cmds = [cmd] + valid = False + for cmd in cmds: + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + valid = valid or result.wait() == 0 + + if not valid and len(cmds) > 1: + print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!") + elif not valid and len(cmds) == 1: + print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!") + return valid + + def warning(self, msg): + self.error_log = f"{msg}" + print(f"{WARNING} {msg}") + + def deepspeed_src_path(self, code_path): + if os.path.isabs(code_path): + return code_path + else: + return os.path.join(Path(__file__).parent.parent.absolute(), code_path) + + def builder(self): + from torch.utils.cpp_extension import CppExtension + return CppExtension(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=self.strip_empty_entries(self.include_paths()), + extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())}, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + + def load(self, verbose=True): + return self.jit_load(verbose) + + def jit_load(self, verbose=True): + if not self.is_compatible(verbose): + raise RuntimeError( + f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" + ) + try: + import ninja # noqa: F401 + except ImportError: + raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") + + if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): + try: + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + except BaseException: + self.build_for_cpu = True + + self.jit_mode = True + from torch.utils.cpp_extension import load + + start_build = time.time() + sources = [self.deepspeed_src_path(path) for path in self.sources()] + extra_include_paths = [self.deepspeed_src_path(path) for path in self.include_paths()] + + # Torch will try and apply whatever CCs are in the arch list at compile time, + # we have already set the intended targets ourselves we know that will be + # needed at runtime. This prevents CC collisions such as multiple __half + # implementations. Stash arch list to reset after build. + torch_arch_list = None + if "TORCH_CUDA_ARCH_LIST" in os.environ: + torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") + os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + op_module = load(name=self.name, + sources=self.strip_empty_entries(sources), + extra_include_paths=self.strip_empty_entries(extra_include_paths), + extra_cflags=self.strip_empty_entries(self.cxx_args()), + extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()), + extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + verbose=verbose) + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") + + # Reset arch list so we are not silently removing it for other possible use cases + if torch_arch_list: + os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + + return op_module + + +class CUDAOpBuilder(OpBuilder): + + def compute_capability_args(self, cross_compile_archs=None): + """ + Returns nvcc compute capability compile flags. + + 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. + 2. If neither is set default compute capabilities will be used + 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX + + Format: + + - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: + + TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... + TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ... + + - `cross_compile_archs` uses ; separator. + + """ + ccs = [] + if self.jit_mode: + # Compile for underlying architectures since we know those at runtime + for i in range(torch.cuda.device_count()): + CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) + cc = f"{CC_MAJOR}.{CC_MINOR}" + if cc not in ccs: + ccs.append(cc) + ccs = sorted(ccs) + ccs[-1] += '+PTX' + else: + # Cross-compile mode, compile for various architectures + # env override takes priority + cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None) + if cross_compile_archs_env is not None: + if cross_compile_archs is not None: + print( + f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`" + ) + cross_compile_archs = cross_compile_archs_env.replace(' ', ';') + else: + if cross_compile_archs is None: + cross_compile_archs = get_default_compute_capabilities() + ccs = cross_compile_archs.split(';') + + ccs = self.filter_ccs(ccs) + if len(ccs) == 0: + raise RuntimeError( + f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") + + args = [] + for cc in ccs: + num = cc[0] + cc[2] + args.append(f'-gencode=arch=compute_{num},code=sm_{num}') + if cc.endswith('+PTX'): + args.append(f'-gencode=arch=compute_{num},code=compute_{num}') + + return args + + def filter_ccs(self, ccs: List[str]): + """ + Prune any compute capabilities that are not compatible with the builder. Should log + which CCs have been pruned. + """ + return ccs + + def version_dependent_macros(self): + # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 + version_ge_1_1 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_3 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_5 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] + return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + + def is_compatible(self, verbose=True): + return super().is_compatible(verbose) + + def builder(self): + try: + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + except BaseException: + self.build_for_cpu = True + + if self.build_for_cpu: + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + else: + from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder + + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \ + {'cxx': self.strip_empty_entries(self.cxx_args()), \ + 'nvcc': self.strip_empty_entries(self.nvcc_args())} + + cuda_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=self.strip_empty_entries(self.include_paths()), + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args) + + if self.is_rocm_pytorch(): + # hip converts paths to absolute, this converts back to relative + sources = cuda_ext.sources + curr_file = Path(__file__).parent.parent # ds root + for i in range(len(sources)): + src = Path(sources[i]) + if src.is_absolute(): + sources[i] = str(src.relative_to(curr_file)) + else: + sources[i] = str(src) + cuda_ext.sources = sources + return cuda_ext + + def hipify_extension(self): + if self.is_rocm_pytorch(): + from torch.utils.hipify import hipify_python + hipify_python.hipify( + project_directory=os.getcwd(), + output_directory=os.getcwd(), + header_include_dirs=self.include_paths(), + includes=[os.path.join(os.getcwd(), '*')], + extra_files=[os.path.abspath(s) for s in self.sources()], + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) + + def cxx_args(self): + if sys.platform == "win32": + return ['-O2'] + else: + return ['-O3', '-std=c++14', '-g', '-Wno-reorder'] + + def nvcc_args(self): + if self.build_for_cpu: + return [] + args = ['-O3'] + if self.is_rocm_pytorch(): + ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() + args += [ + '-std=c++14', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', + '-U__HIP_NO_HALF2_OPERATORS__', + '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, + '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR + ] + else: + cuda_major, _ = installed_cuda_version() + args += [ + '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', + '-std=c++17' if sys.platform == "win32" and cuda_major > 10 else '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__' + ] + if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1': + args.append('--ptxas-options=-v') + args += self.compute_capability_args() + return args + + def libraries_args(self): + if self.build_for_cpu: + return [] + + if sys.platform == "win32": + return ['cublas', 'curand'] + else: + return [] + + +class TorchCPUOpBuilder(CUDAOpBuilder): + + def extra_ldflags(self): + if self.build_for_cpu: + return ['-fopenmp'] + + if not self.is_rocm_pytorch(): + return ['-lcurand'] + + return [] + + def cxx_args(self): + import torch + args = [] + if not self.build_for_cpu: + if not self.is_rocm_pytorch(): + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") + else: + CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") + + args += super().cxx_args() + args += [ + f'-L{CUDA_LIB64}', + '-lcudart', + '-lcublas', + '-g', + ] + + CPU_ARCH = self.cpu_arch() + SIMD_WIDTH = self.simd_width() + CUDA_ENABLE = self.is_cuda_enable() + args += [ + CPU_ARCH, + '-fopenmp', + SIMD_WIDTH, + CUDA_ENABLE, + ] + + return args + +class InferenceBuilder(CUDAOpBuilder): + BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE" + NAME = "transformer_inference" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'cai.inference.{self.NAME}_op' + + def is_compatible(self, verbose=True): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + + cuda_okay = True + if not self.is_rocm_pytorch() and torch.cuda.is_available(): + sys_cuda_major, _ = installed_cuda_version() + torch_cuda_major = int(torch.version.cuda.split('.')[0]) + cuda_capability = torch.cuda.get_device_properties(0).major + if cuda_capability < 6: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + cuda_okay = False + if cuda_capability >= 8: + if torch_cuda_major < 11 or sys_cuda_major < 11: + self.warning("On Ampere and higher architectures please use CUDA 11+") + cuda_okay = False + return super().is_compatible(verbose) and cuda_okay + + def filter_ccs(self, ccs): + ccs_retained = [] + ccs_pruned = [] + for cc in ccs: + if int(cc[0]) >= 6: + ccs_retained.append(cc) + else: + ccs_pruned.append(cc) + if len(ccs_pruned) > 0: + self.warning(f"Filtered compute capabilities {ccs_pruned}") + return ccs_retained + + def sources(self): + return [ + 'gptq/csrc/pt_binding.cpp', + 'gptq/csrc/gptq_act_linear.cu', + ] + + def extra_ldflags(self): + if not self.is_rocm_pytorch(): + return ['-lcurand', '-L/home/lcxk/data3/anaconda3/envs/triton/lib'] + else: + return [] + + def include_paths(self): + return ['gptq/csrc/includes'] + + +builder = InferenceBuilder() +inference_cuda_module = builder.load() \ No newline at end of file diff --git a/tests/test_gptq/linear_act_fusion_bench.py b/tests/test_gptq/linear_act_fusion_bench.py new file mode 100644 index 000000000000..13d716412f47 --- /dev/null +++ b/tests/test_gptq/linear_act_fusion_bench.py @@ -0,0 +1,385 @@ + +import torch +import torch.nn as nn + +import time +from argparse import ArgumentParser + +import transformers +from colossalai.gptq.gptq_utils import GPTQ +from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders +from colossalai.gptq.gptq_utils import quant +from colossalai.gptq.gptq_utils.quant import Quantizer +from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp +import math +import numpy as np +from colossalai.gptq import CaiInferenceConfig +import csv + +class MLinear(nn.Module): + def __init__(self, infeature, outfeature): + super(MLinear, self).__init__() + self.linear = torch.nn.Linear(infeature, outfeature, dtype=torch.float16) + def forward(self, x): + out = self.linear(x) + return out + +@torch.no_grad() +def model_quant(model, inps, dev, args): + print('Starting ...') + layers = [model] + layers[0] = layers[0].to(dev) + dtype = next(iter(model.parameters())).dtype + cache = {'i': 0} + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + raise ValueError + layers[0] = Catcher(layers[0]) + # for batch in inps: + try: + model(inps.to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + + # outs = torch.zeros_like(inps) + outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) + + print('Ready.') + + quantizers = {} + for i in range(len(layers)): + layer = layers[i].to(dev) + + subset = find_layers(layer) + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure( args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits ) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0))[0] + + for h in handles: + h.remove() + for name in subset: + print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') + scale,zero,g_idx,error= gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + gptq[name].free() + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0))[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + + return quantizers + +def model_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + quant.make_quant_linear(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [quant.QuantLinear]) + print('Packing ...') + for name in qlayers: + quantizers[name], scale, zero, g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return qlayers['linear'] + + + +def cai_linear_pack(linear, scales, zeros, + out_qweight, out_qscales, out_qzeros, qg_idx, + infeatures, groupsize, bits): + g_idx = qg_idx.clone() if qg_idx is not None else torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + + out_qscales.data.copy_(scales) + + wn = 16 + pbits = 64 + ptype = torch.int64 + unsign_type = np.uint64 + sign_type = np.int64 + + # wn = 8 + # pbits = 32 + # ptype = torch.int32 + # unsign_type = np.uint32 + # sign_type = np.int32 + + intweight = [] + for idx in range(infeatures): + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits * bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + # print("weight shape ", intweight.shape, qweight.shape, out_qweight.shape, bits) + # print("weight shape ", intweight[0].shape, qweight[0].shape, out_qweight[0].shape) + # print("weight value ", intweight[0], qweight[0]) + + while row < qweight.shape[0]: + if bits in [2, 4, 8]: + for j in range(i, i + (pbits // bits)): + qweight[row] |= intweight[j] << (bits * (j - i)) + i += pbits // bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous().to("cuda") + out_qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if bits in [2, 4, 8]: + for j in range(i, i + (pbits // bits)): + qzeros[:, col] |= zeros[:, j] << (bits * (j - i)) + i += pbits // bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros.to("cuda") + out_qzeros.data.copy_(qzeros) + + return out_qweight, out_qscales, out_qzeros + +def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + with torch.no_grad(): + for name in layers: + _, scale, zero, g_idx = quantizers[name] + qweight, qscales, qzeros = cai_linear_pack(layers[name], scale, zero, + qweight, qscales, qzeros, g_idx, + layers[name].weight.shape[-1], groupsize, wbits) + + # print("cai pack", layers) + return qweight, qscales, qzeros +if __name__ == "__main__": + + + parser = ArgumentParser() + parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') + parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') + parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') + parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration data samples.') + parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') + parser.add_argument('--groupsize', type=int, default=128, help='Groupsize to use for quantization; default uses full row.') + parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') + args = parser.parse_args() + infeature = 8192 + outfeature = 8192 + + weight = torch.randn(outfeature, infeature).to(torch.float16).to(torch.cuda.current_device()) + bias = torch.zeros(outfeature).to(torch.float16).to(torch.cuda.current_device()) + wn = 16 + ptype = torch.int64 + + # wn = 8 + # ptype = torch.int32 + + qweight = torch.zeros(infeature//wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() + qscales = torch.zeros(infeature//args.groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() + qzeros = torch.zeros(infeature//args.groupsize, outfeature//wn, dtype=ptype, device=torch.cuda.current_device()).contiguous() + + + linear = MLinear(infeature, outfeature) + linear.to(torch.cuda.current_device()) + with torch.no_grad(): + linear.linear.weight.data.copy_(weight) + linear.linear.bias.data.copy_(bias) + inps = torch.randn(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + quantizers = model_quant(linear, inps, torch.cuda.current_device(), args) + qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) + + + batch_inps = torch.randn(1, 16384, infeature).to(torch.float16).to(torch.cuda.current_device()) + + gptq_linear_time = 0 + torch_linear_time = 0 + warm_up_iter = 2 + benchmark_iter = 100 + + act_func = nn.ReLU() + linear.to("cuda") + for i in range(0, warm_up_iter): + with torch.no_grad(): + torch_out = act_func(inps) + # torch_out = inps + # print(f"torch out {torch_out}") + torch_out = linear(torch_out) + torch.cuda.synchronize() + + time_start = time.time() + for i in range(0, benchmark_iter): + with torch.no_grad(): + torch_out = act_func(inps) + # torch_out = inps + torch_out = linear(torch_out) + torch.cuda.synchronize() + + time_end = time.time() + torch_linear_time = time_end - time_start + + + time_start = time.time() + for i in range(0, benchmark_iter): + with torch.no_grad(): + torch_out = act_func(batch_inps) + # torch_out = inps + torch_out = linear(torch_out) + torch.cuda.synchronize() + + time_end = time.time() + torch_batch_linear_time = time_end - time_start + + linear.to("cpu") + + gptq_model = model_pack(linear, quantizers, args.wbits, args.groupsize) + gptq_model.to(torch.cuda.current_device()) + + # gptq_model = linear + + for i in range(0, warm_up_iter): + with torch.no_grad(): + gptq_out = act_func(inps) + # gptq_out = inps + gptq_out = gptq_model(gptq_out) + torch.cuda.synchronize() + + time_start = time.time() + for i in range(0, benchmark_iter): + with torch.no_grad(): + gptq_out = act_func(inps) + # gptq_out = inps + gptq_out = gptq_model(gptq_out) + torch.cuda.synchronize() + + time_end = time.time() + + gptq_linear_time = time_end - time_start + + for i in range(0, warm_up_iter): + with torch.no_grad(): + gptq_out = act_func(batch_inps) + # gptq_out = inps + gptq_out = gptq_model(gptq_out) + torch.cuda.synchronize() + + time_start = time.time() + for i in range(0, benchmark_iter): + with torch.no_grad(): + gptq_out = act_func(batch_inps) + # gptq_out = inps + gptq_out = gptq_model(gptq_out) + torch.cuda.synchronize() + + time_end = time.time() + + gptq_batch_linear_time = time_end - time_start + + # qweight = torch.cat((qweight, qweight, qweight), dim=0).contiguous() + # qscales = torch.cat((qscales, qscales, qscales), dim=0).contiguous() + # qzeros = torch.cat((qzeros, qzeros, qzeros), dim=0).contiguous() + # bias = torch.cat((bias, bias, bias), dim=0).contiguous() + qkv_fused = False + cai_inf_config = CaiInferenceConfig(fp16=True) + + cai_linear = CaiGPTQLinearOp(cai_inf_config) + + print("cai linear") + for i in range(0, warm_up_iter): + with torch.no_grad(): + cai_out = cai_linear(inps, + qweight, + qscales, + qzeros, + act_type=0, + bias = bias, + qkv_fused = qkv_fused) + torch.cuda.synchronize() + + + print("warm up cai linear") + + # f = open('cai_time.csv', 'w') + # writer = csv.writer(f) + + + for i in range(0, warm_up_iter): + with torch.no_grad(): + cai_out = cai_linear(batch_inps, + qweight, + qscales, + qzeros, + act_type=0, + bias = bias, + qkv_fused = qkv_fused) + torch.cuda.synchronize() + + cai_linear_time = time_end - time_start + # print("block dim x:{}, block dim y:{}, time: {:.8f} ".format(i, j, cai_linear_time/benchmark_iter)) + # row=[i, j, cai_linear_time/benchmark_iter] + + + time_start = time.time() + for k in range(0, benchmark_iter): + with torch.no_grad(): + cai_out = cai_linear(batch_inps, + qweight, + qscales, + qzeros, + act_type=0, + bias = bias, + qkv_fused = qkv_fused) + torch.cuda.synchronize() + time_end = time.time() + + batch_cai_linear_time = time_end - time_start + + print("torch time: {:.8f}".format(torch_linear_time/benchmark_iter)) + print("gptq time:{:.8f}".format( gptq_linear_time/benchmark_iter)) + print("cai gptq time:{:.8f}".format( cai_linear_time/benchmark_iter)) + + print("batch torch time: {:.8f}".format(torch_batch_linear_time/benchmark_iter)) + print("batch gptq time:{:.8f}".format( gptq_batch_linear_time/benchmark_iter)) + print("batch cai gptq time:{:.8f}".format( batch_cai_linear_time/benchmark_iter)) diff --git a/tests/test_gptq/quant_llama.py b/tests/test_gptq/quant_llama.py new file mode 100644 index 000000000000..52ac0d81f6b1 --- /dev/null +++ b/tests/test_gptq/quant_llama.py @@ -0,0 +1,569 @@ +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +from colossalai.gptq.gptq_utils import quant +from colossalai.gptq import cai_gptq + +from colossalai.gptq.gptq_utils import GPTQ, Observer +from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions +from texttable import Texttable +from colossalai.gptq import CaiInferenceConfig +from transformers import LlamaForCausalLM, LlamaTokenizer + +import csv + +def get_llama(model): + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import LlamaForCausalLM, LlamaConfig, LlamaModel + if args.debug: + llama_kwargs= {"bos_token_id": 0, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "max_sequence_length": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "pad_token_id": -1, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "use_cache": True, + "vocab_size": 32000 + } + configuration = LlamaConfig( **llama_kwargs + ) + model = LlamaForCausalLM(configuration) + else: + model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16) + + # # LlamaForCausalLM + model.seqlen = 2048 + return model + + +@torch.no_grad() +def llama_sequential(model, dataloader, dev): + print('Starting ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.layers + + model.model.embed_tokens = model.model.embed_tokens.to(dev) + model.model.norm = model.model.norm.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) + cache = {'i': 0, 'attention_mask': None} + + class Catcher(nn.Module): + + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + cache['position_ids'] = kwargs['position_ids'] + raise ValueError + + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.model.embed_tokens = model.model.embed_tokens.cpu() + model.model.norm = model.model.norm.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + position_ids = cache['position_ids'] + + print('Ready.') + + quantizers = {} + observer = Observer() + for i in range(len(layers)): + + print(f'Quantizing layer {i+1}/{len(layers)}..') + print('+------------------+--------------+------------+-----------+-------+') + print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') + print('+==================+==============+============+===========+=======+') + + layer = layers[i].to(dev) + full = find_layers(layer) + if args.true_sequential: + sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']] + else: + sequential = [list(full.keys())] + + for names in sequential: + subset = {n: full[n] for n in names} + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name], observe=args.observe) + gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) + + def add_batch(name): + + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] + for h in handles: + h.remove() + + for name in subset: + scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name) + quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize) + + if args.observe: + observer.submit(name=name, layerid=i, gptq=gptq[name], error=error) + else: + gptq[name].free() + + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + print('+------------------+--------------+------------+-----------+-------+') + print('\n') + + if args.observe: + observer.print() + conditions = gen_conditions(args.wbits, args.groupsize) + for item in observer.items(): + name = item[0] + layerid = item[1] + gptq = item[2]['gptq'] + error = item[2]['error'] + target = error / 2 + + table = Texttable() + table.header(['wbits', 'groupsize', 'error']) + table.set_cols_dtype(['i', 'i', 'f']) + table.add_row([args.wbits, args.groupsize, error]) + + print('Optimizing {} {} ..'.format(name, layerid)) + for wbits, groupsize in conditions: + + if error < target: + # if error dropped 50%, skip + break + + gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False) + + scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name) + + table.add_row([wbits, groupsize, error]) + quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) + + print(table.draw()) + print('\n') + gptq.layer.to('cpu') + gptq.free() + + model.config.use_cache = use_cache + + return quantizers + + +# TODO: perform packing on GPU +def cai_llama_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + # print(f"model {model}") + # print(f"layers {layers}") + + layers = {n: layers[n] for n in quantizers} + # print(f"quantizers {quantizers}") + cai_gptq.make_cai_quant_linear(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [cai_gptq.CaiQuantLinear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return model + +def gptq_llama_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + # print(f"model {model}") + # print(f"layers {layers}") + + layers = {n: layers[n] for n in quantizers} + # print(f"quantizers {quantizers}") + quant.make_quant_linear(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [quant.QuantLinear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return model + + +def cai_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): + from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils + config = LlamaConfig.from_pretrained(model) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = LlamaForCausalLM(config) + torch.set_default_dtype(torch.float) + if eval: + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + cai_gptq.make_cai_quant_linear(model, layers, wbits, groupsize) + + del layers + + print('Loading model ...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + + print('Done.') + + return model + + +def gptq_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): + from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils + config = LlamaConfig.from_pretrained(model) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = LlamaForCausalLM(config) + torch.set_default_dtype(torch.float) + if eval: + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + quant.make_quant_linear(model, layers, wbits, groupsize) + + del layers + + print('Loading model ...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + + print('Done.') + + return model + +all_perfs = [] +now_perf=[] + +def print_perf_stats(latency_set, config, warmup=3): + global now_perf + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + # if args.dtype == "float16": + # num_bytes = 2 + # elif args.dtype == "float32": + # num_bytes = 4 + # else: + # num_bytes = 1 + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1/avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1/avg * num_parameters * num_bytes * args.batch_size / 1e12)) + print("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) + print("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) + row = [args.batch_size, args.input_len, args.max_new_tokens, "{0:8.2f}".format(avg * 1000), + "{0:8.2f}".format(torch.cuda.memory_allocated() / 1e9), + "{0:8.2f}".format(torch.cuda.max_memory_allocated()/1e9)] + with open('./{}_profile.csv'.format(args.model_type), 'a', encoding='UTF8') as f: + # create the csv writer + writer = csv.writer(f) + + # write a row to the csv file + writer.writerow(row) + + now_perf.append("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + now_perf.append("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) + now_perf.append("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) + + all_perfs.append(now_perf) + now_perf = [] + +def benchmark(model): + + input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), + "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} + torch.cuda.synchronize() + iters = 10 if args.benchmark else 2 #warmup + print(f"model config {model.config}") + + times = [] + warmup=3 + prof_flag = 0 + generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) + torch.cuda.reset_peak_memory_stats() + for i in range(iters): + if i >= warmup: + prof_flag=1 + torch.cuda.synchronize() + start = time.time() + outputs = model.generate(**input_tokens, + **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + times.append(end - start) + print("outpus shape: ", outputs.shape) + print(args) + print("input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens) + # if args.local_rank == 0: + now_perf.extend(["input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens]) + print_perf_stats(map(lambda t: t / args.max_new_tokens, times), model.config) + +def test(model_1, model_2): + # input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), + # "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} + generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) + + + tokenizer = LlamaTokenizer.from_pretrained(args.model) + tokenizer.pad_token_id = tokenizer.unk_token_id + + text = "how is weather today? I want to know the weather of beijing. " + text = "how are you?" + + inputs = [text] + input_tokens = tokenizer.batch_encode_plus(inputs, padding = True, return_tensors="pt") + + input_len = 0 + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + # print(input_tokens[t].shape) + input_len = input_tokens[t].shape[1] + + outputs_1 = model_1.generate(**input_tokens, + **generate_kwargs) + print("model 1 done") + out_1 = tokenizer.batch_decode(outputs_1) + + print("decode out:", out_1) + if model_2 is None: + return + outputs_2 = model_2.generate(**input_tokens, + **generate_kwargs) + print("model 2 done") + + out_2 = tokenizer.batch_decode(outputs_2) + + ret = torch.allclose(outputs_1, outputs_2) + print("allclose is ", ret) + + print("decode out:", out_2) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + parser.add_argument('model', type=str, help='llama model to load') + parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') + parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') + parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration data samples.') + parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') + parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.') + parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') + parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') + parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') + parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') + parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.') + parser.add_argument('--load', type=str, default='', help='Load quantized model.') + parser.add_argument('--benchmark', action='store_true', help='Number of tokens to use for benchmarking.') + parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') + parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') + parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') + parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential model.') + parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.') + parser.add_argument('--observe', + action='store_true', + help='Auto upgrade layer precision to higher precision, for example int2 to int4, groupsize 128 to 64. \ + When this feature enabled, `--save` or `--save_safetensors` would be disable.') + parser.add_argument('--quant-directory', type=str, default=None, help='Specify the directory for export quantization parameters to toml format. `None` means no export by default.') + parser.add_argument('--max_new_tokens', type=int, default=32, help='Max new tokens to generate.') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size to generate.') + parser.add_argument('--input_len', type=int, default=128, help='Batch size to generate.') + parser.add_argument('--model_type', type=str, choices=['cai', 'gptq', 'torch'], default='torch', help='Batch size to generate.') + parser.add_argument('--debug', action='store_true', help='Whether to debug or not') + + args = parser.parse_args() + + model_packed = False + if type(args.load) is not str: + args.load = args.load.as_posix() + + if args.load: + if args.model_type == "gptq": + model = gptq_load_quant(args.model, args.load, args.wbits, args.groupsize) + elif args.model_type == "cai": + model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) + else: + model = get_llama(args.model) + model.half() + + if not args.load and args.wbits < 16 and not args.nearest and args.model_type in ['cai', 'gptq']: + dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen) + tick = time.time() + quantizers = llama_sequential(model, dataloader, DEV) + if args.model_type == "cai": + cai_llama_pack(model, quantizers, args.wbits, args.groupsize) + elif args.model_type == "gptq": + gptq_llama_pack(model, quantizers, args.wbits, args.groupsize) + model_packed = True + print(time.time() - tick) + + + if args.quant_directory is not None: + export_quant_table(quantizers, args.quant_directory) + + if not args.observe and args.save and args.model_type in ['cai', 'gptq']: + if not model_packed: + llama_pack(model, quantizers, args.wbits, args.groupsize) + model_packed = True + torch.save(model.state_dict(), args.save) + + if not args.observe and args.save_safetensors and args.model_type in ['cai', 'gptq']: + if not model_packed: + llama_pack(model, quantizers, args.wbits, args.groupsize) + from safetensors.torch import save_file as safe_save + state_dict = model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + safe_save(state_dict, args.save_safetensors) + + if args.benchmark: + # model = model.to(DEV) + # print(f"model config {model.config.num_key_value_heads}") + + # if args.model_type == "cai": + # cai_inf_config = CaiInferenceConfig(fp16=True, + # device=torch.cuda.current_device(), + # gptq=True, + # gptq_group_size=128, + # gptq_quant_bits=4) + # model = convert_to_ds_model(model, cai_inf_config) + # model.cuda().to(torch.cuda.current_device()) + + + torch_model = get_llama(args.model) + torch_model.half() + torch_model = torch_model.to(DEV) + + gptq_model = gptq_load_quant(args.model, "llama7b-4bit-128g-gptq-nao.pt", args.wbits, args.groupsize) + gptq_model = gptq_model.to(DEV) + + model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) + model = model.to(DEV) + + + test(torch_model, model) + test(gptq_model, None) + + print("torch_model ", torch_model) + print("gptq_model ", gptq_model) + print("cai_model ", model) + torch_qkv_out = torch_model.model.layers[0].self_attn.qkv_out + cai_qkv_out = model.model.layers[0].self_attn.qkv_out + gptq_qkv_out = gptq_model.model.layers[0].self_attn.qkv_out + + gptq_out = gptq_model.model.layers[0].self_attn.q_proj.scales + cai_out = model.model.layers[0].self_attn.q_proj.scales + + mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) + max_diff = torch.max(torch.abs(cai_out - gptq_out)) + print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + for i in range(3): + cai_out = cai_qkv_out[i] + torch_out = torch_qkv_out[i] + gptq_out = gptq_qkv_out[i] + mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) + max_diff = torch.max(torch.abs(cai_out - gptq_out)) + print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) + max_diff = torch.max(torch.abs(torch_out - gptq_out)) + print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + mean_diff = torch.mean(torch.abs(torch_out - cai_out)) + max_diff = torch.max(torch.abs(torch_out - cai_out)) + print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + + # # for batch in [1, 2, 4, 8, 16, 32]: + # for batch in [1]: + # args.batch_size = batch + # # for in_len in [128, 256, 512, 1024, 2048]: + # for in_len in [1024]: + # args.input_len = in_len + # benchmark(model) + # # for info in all_perfs: + # # print(info) + # # # all_perfs = [] \ No newline at end of file diff --git a/tests/test_gptq/run_gptq.sh b/tests/test_gptq/run_gptq.sh new file mode 100644 index 000000000000..03aaca8f60df --- /dev/null +++ b/tests/test_gptq/run_gptq.sh @@ -0,0 +1,19 @@ +# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ +# --wbits 4 --true-sequential --groupsize 128 --save ./llama7b-4bit-128g-cai-nao.pt\ +# --benchmark --model_type cai --input_len 1024 --max_new_tokens 128 --batch_size 1 + +# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ +# --wbits 4 --true-sequential --groupsize 128 --save ./llama7b-4bit-128g-gptq-nao.pt\ +# --benchmark --model_type gptq --input_len 1024 --max_new_tokens 128 --batch_size 1 + +OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ + --wbits 4 --true-sequential --act-order --groupsize 128 --load ./llama7b-4bit-128g-cai-nao.pt\ + --benchmark --model_type cai --input_len 1024 --max_new_tokens 128 --batch_size 1 + +# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ +# --wbits 4 --true-sequential --act-order --groupsize 128 --load /llama7b-4bit-128g-gptq-nao.pt \ +# --benchmark --model_type gptq --input_len 1024 --max_new_tokens 128 --batch_size 1 + +# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=4 python quant_llama.py /data/scratch/llama-13b-hf c4 \ +# --wbits 4 --true-sequential --act-order --groupsize 128 \ +# --benchmark --model_type torch --input_len 1024 --max_new_tokens 128 --batch_size 1 diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py new file mode 100644 index 000000000000..8079a9966843 --- /dev/null +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -0,0 +1,402 @@ + +import torch +import torch.nn as nn + +import time +from argparse import ArgumentParser + +import transformers +from colossalai.gptq.gptq_utils import GPTQ +from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders +from colossalai.gptq.gptq_utils import quant +from colossalai.gptq.gptq_utils.quant import Quantizer +from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp +import math +import numpy as np +from colossalai.gptq import CaiInferenceConfig +import csv + +class MLinear(nn.Module): + def __init__(self, infeature, outfeature): + super(MLinear, self).__init__() + self.linear = torch.nn.Linear(infeature, outfeature, dtype=torch.float16) + def forward(self, x): + out = self.linear(x) + return out + +@torch.no_grad() +def model_quant(model, inps, dev, args): + print('Starting ...') + layers = [model] + layers[0] = layers[0].to(dev) + dtype = next(iter(model.parameters())).dtype + cache = {'i': 0} + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + raise ValueError + layers[0] = Catcher(layers[0]) + # for batch in inps: + try: + model(inps.to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + + # outs = torch.zeros_like(inps) + outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) + + print('Ready.') + + quantizers = {} + for i in range(len(layers)): + layer = layers[i].to(dev) + + subset = find_layers(layer) + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name]) + gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure( args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits ) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0))[0] + + for h in handles: + h.remove() + for name in subset: + print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') + scale,zero,g_idx,error= gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + gptq[name].free() + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0))[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + + return quantizers + +def model_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + quant.make_quant_linear(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [quant.QuantLinear]) + print('Packing ...') + for name in qlayers: + quantizers[name], scale, zero, g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return qlayers['linear'] + + + +def cai_linear_pack(linear, scales, zeros, + out_qweight, out_qscales, out_qzeros, qg_idx, + infeatures, groupsize, bits): + g_idx = qg_idx.clone() if qg_idx is not None else torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + + out_qscales.data.copy_(scales) + + wn = 16 + pbits = 64 + ptype = torch.int64 + unsign_type = np.uint64 + sign_type = np.int64 + + # wn = 8 + # pbits = 32 + # ptype = torch.int32 + # unsign_type = np.uint32 + # sign_type = np.int32 + + intweight = [] + for idx in range(infeatures): + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits * bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + # print("weight shape ", intweight.shape, qweight.shape, out_qweight.shape, bits) + # print("weight shape ", intweight[0].shape, qweight[0].shape, out_qweight[0].shape) + # print("weight value ", intweight[0], qweight[0]) + + while row < qweight.shape[0]: + if bits in [2, 4, 8]: + for j in range(i, i + (pbits // bits)): + qweight[row] |= intweight[j] << (bits * (j - i)) + i += pbits // bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous().to("cuda") + out_qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if bits in [2, 4, 8]: + for j in range(i, i + (pbits // bits)): + qzeros[:, col] |= zeros[:, j] << (bits * (j - i)) + i += pbits // bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros.to("cuda") + out_qzeros.data.copy_(qzeros) + + return out_qweight, out_qscales, out_qzeros + +def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + with torch.no_grad(): + for name in layers: + _, scale, zero, g_idx = quantizers[name] + qweight, qscales, qzeros = cai_linear_pack(layers[name], scale, zero, + qweight, qscales, qzeros, g_idx, + layers[name].weight.shape[-1], groupsize, wbits) + + # print("cai pack", layers) + return qweight, qscales, qzeros +if __name__ == "__main__": + + + parser = ArgumentParser() + parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') + parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') + parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') + parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration data samples.') + parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') + parser.add_argument('--groupsize', type=int, default=128, help='Groupsize to use for quantization; default uses full row.') + parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') + args = parser.parse_args() + infeature = 5120 + outfeature = 5120 + + weight = torch.randn(outfeature, infeature).to(torch.float16).to(torch.cuda.current_device()) + bias = torch.zeros(outfeature).to(torch.float16).to(torch.cuda.current_device()) + wn = 16 + ptype = torch.int64 + + # wn = 8 + # ptype = torch.int32 + + qweight = torch.zeros(infeature//wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() + qscales = torch.zeros(infeature//args.groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() + qzeros = torch.zeros(infeature//args.groupsize, outfeature//wn, dtype=ptype, device=torch.cuda.current_device()).contiguous() + + # print(linear.linear.weight) + act_func = nn.SiLU() + + inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) + + linear = MLinear(infeature, outfeature) + linear.to(torch.cuda.current_device()) + with torch.no_grad(): + linear.linear.weight.data.copy_(weight) + linear.linear.bias.data.copy_(bias) + + with torch.no_grad(): + torch_out = linear(inps) + batch_torch_out = linear(batch_inps) + torch_out = act_func(torch_out) + batch_torch_out = act_func(batch_torch_out) + print("batch_torch out ", batch_torch_out) + + linear.to("cpu") + quantizers = model_quant(linear, inps, torch.cuda.current_device(), args) + qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) + gptq_model = model_pack(linear, quantizers, args.wbits, args.groupsize) + gptq_model.to(torch.cuda.current_device()) + # gptq_model = linear + + cai_inf_config = CaiInferenceConfig(fp16=True) + cai_linear = CaiGPTQLinearOp(cai_inf_config) + + + # qweight = torch.cat((qweight, qweight, qweight), dim=0).contiguous() + # qscales = torch.cat((qscales, qscales, qscales), dim=0).contiguous() + # qzeros = torch.cat((qzeros, qzeros, qzeros), dim=0).contiguous() + # bias = torch.cat((bias, bias, bias), dim=0).contiguous() + qkv_fused=False + # inps[:, :, 256:] = 0 + + with torch.no_grad(): + gptq_out = gptq_model(inps) + batch_gptq_out = gptq_model(batch_inps) + cai_out = cai_linear(inps, + qweight, + qscales, + qzeros, + bias = bias, + act_type = 3, + qkv_fused=qkv_fused) + torch.cuda.synchronize() + + batch_cai_out = cai_linear(batch_inps, + qweight, + qscales, + qzeros, + bias=bias, + act_type = 3, + qkv_fused=qkv_fused) + torch.cuda.synchronize() + batch_gptq_out = act_func(batch_gptq_out) + gptq_out = act_func(gptq_out) + + + # cai_out = cai_out[1] + # batch_cai_out = batch_cai_out[1] + a = torch.sum(qscales, 0) + print("qscales ", a) + print("orch out ", torch_out) + print("gptq out ", gptq_out) + print("cai out ", cai_out) + + print("batch_torch out ", batch_torch_out) + print("batch_gptq out ", batch_gptq_out) + print("batch_cai out ", batch_cai_out) + + mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) + max_diff = torch.max(torch.abs(cai_out - gptq_out)) + print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) + max_diff = torch.max(torch.abs(torch_out - gptq_out)) + print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + mean_diff = torch.mean(torch.abs(torch_out - cai_out)) + max_diff = torch.max(torch.abs(torch_out - cai_out)) + print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + + mean_diff = torch.mean(torch.abs(batch_cai_out - batch_gptq_out)) + max_diff = torch.max(torch.abs(batch_cai_out - batch_gptq_out)) + print("batch cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + mean_diff = torch.mean(torch.abs(batch_torch_out - batch_gptq_out)) + max_diff = torch.max(torch.abs(batch_torch_out - batch_gptq_out)) + print("batch torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + mean_diff = torch.mean(torch.abs(batch_torch_out - batch_cai_out)) + max_diff = torch.max(torch.abs(batch_torch_out - batch_cai_out)) + print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + + # print("torch time: {:.8f}, gptq time:{:.8f}, cai time: {:.8f} ".format(torch_linear_time/benchmark_iter, gptq_linear_time/benchmark_iter, cai_linear_time/benchmark_iter)) + # print("torch time: {:.8f}, gptq time:{:.8f}, cai time: {:.8f} ".format(torch_linear_time/benchmark_iter, gptq_linear_time/benchmark_iter, cai_linear_time/benchmark_iter)) + + + + + + # linear = MLinear(infeature, outfeature) + # linear.to(torch.cuda.current_device()) + # with torch.no_grad(): + # linear.linear.weight.data.copy_(weight) + # linear.linear.bias.data.copy_(bias) + # inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + # quantizers = model_quant(linear, inps, torch.cuda.current_device(), args) + # qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) + # cai_inf_config = CaiInferenceConfig(fp16=True, device=torch.cuda.current_device()) + + # cai_linear = GPTQActLinearOp(cai_inf_config) + + # batch_inps = torch.randn(1, 4, infeature).to(torch.float16).to(torch.cuda.current_device()) + + # relu = nn.ReLU() + # # act_inps = relu(inps) + # # act_batch_inps = relu(batch_inps) + # # batch_inps = torch.ones(1, 2, infeature).to(torch.float16).to(torch.cuda.current_device()) + # # inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + # # gptq_out = relu(inps) + # linear.to("cuda") + # with torch.no_grad(): + # torch_out = linear(inps) + # torch_batch_out = linear(batch_inps) + # # torch_out = relu(torch_out) + # # torch_batch_out = relu(torch_batch_out) + + # linear.to("cpu") + + # gptq_model = model_pack(linear, quantizers, args.wbits, args.groupsize) + # gptq_model.to(torch.cuda.current_device()) + + + # with torch.no_grad(): + # gptq_out = gptq_model(inps) + # cai_out = cai_linear(inps, + # qweight, + # qscales, + # qzeros, + # act_type = 1, + # bias = bias) + # gptq_batch_out = gptq_model(batch_inps) + # cai_batch_out = cai_linear(batch_inps, + # qweight, + # qscales, + # qzeros, + # act_type = 1, + # bias = bias) + + # torch.cuda.synchronize() + # # gptq_out = relu(gptq_out) + # # re_gptq_batch_out = relu(gptq_batch_out) + # # print(f"cai out {cai_out}") + # # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) + # # max_diff = torch.max(torch.abs(cai_out - gptq_out)) + # # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # # mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) + # # max_diff = torch.max(torch.abs(torch_out - gptq_out)) + # # print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # # mean_diff = torch.mean(torch.abs(torch_out - cai_out)) + # # max_diff = torch.max(torch.abs(torch_out - cai_out)) + # # print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + + # print(f"cai batch out {cai_batch_out}") + # print(f"gptq batch out {gptq_batch_out}") + # print(f"torch batch out {torch_batch_out}") + # # print(f"gptq batch out {re_gptq_batch_out}") + + # mean_diff = torch.mean(torch.abs(cai_batch_out - gptq_batch_out)) + # max_diff = torch.max(torch.abs(cai_batch_out - gptq_batch_out)) + # print("cai vs gptq batch 128: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(torch_batch_out - gptq_batch_out)) + # max_diff = torch.max(torch.abs(torch_batch_out - gptq_batch_out)) + # print("torch vs gptq batch 128: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(torch_batch_out - cai_batch_out)) + # max_diff = torch.max(torch.abs(torch_batch_out - cai_batch_out)) + # print("torch vs cai batch 128: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + + diff --git a/tests/test_gptq/test_quant_llama.py b/tests/test_gptq/test_quant_llama.py new file mode 100644 index 000000000000..9f73a116f5bf --- /dev/null +++ b/tests/test_gptq/test_quant_llama.py @@ -0,0 +1,530 @@ +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +from colossalai.gptq.gptq_utils import quant +from colossalai.gptq import cai_gptq + +from colossalai.gptq.gptq_utils import GPTQ, Observer +from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions +from texttable import Texttable +from colossalai.gptq import CaiInferenceConfig +from transformers import LlamaForCausalLM, LlamaTokenizer + +import csv + +def get_llama(model): + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import LlamaForCausalLM, LlamaConfig, LlamaModel + if args.debug: + llama_kwargs= {"bos_token_id": 0, + "eos_token_id": 1, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "max_sequence_length": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "pad_token_id": -1, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "use_cache": True, + "vocab_size": 32000 + } + configuration = LlamaConfig( **llama_kwargs + ) + model = LlamaForCausalLM(configuration) + else: + model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16) + + # # LlamaForCausalLM + model.seqlen = 2048 + return model + + +@torch.no_grad() +def llama_sequential(model, dataloader, dev): + print('Starting ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.layers + + model.model.embed_tokens = model.model.embed_tokens.to(dev) + model.model.norm = model.model.norm.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) + cache = {'i': 0, 'attention_mask': None} + + class Catcher(nn.Module): + + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + cache['position_ids'] = kwargs['position_ids'] + raise ValueError + + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.model.embed_tokens = model.model.embed_tokens.cpu() + model.model.norm = model.model.norm.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + position_ids = cache['position_ids'] + + print('Ready.') + + quantizers = {} + observer = Observer() + for i in range(len(layers)): + + print(f'Quantizing layer {i+1}/{len(layers)}..') + print('+------------------+--------------+------------+-----------+-------+') + print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') + print('+==================+==============+============+===========+=======+') + + layer = layers[i].to(dev) + full = find_layers(layer) + if args.true_sequential: + sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']] + else: + sequential = [list(full.keys())] + + for names in sequential: + subset = {n: full[n] for n in names} + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name], observe=args.observe) + gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) + + def add_batch(name): + + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] + for h in handles: + h.remove() + + for name in subset: + scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name) + quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize) + + if args.observe: + observer.submit(name=name, layerid=i, gptq=gptq[name], error=error) + else: + gptq[name].free() + + for j in range(args.nsamples): + outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + print('+------------------+--------------+------------+-----------+-------+') + print('\n') + + if args.observe: + observer.print() + conditions = gen_conditions(args.wbits, args.groupsize) + for item in observer.items(): + name = item[0] + layerid = item[1] + gptq = item[2]['gptq'] + error = item[2]['error'] + target = error / 2 + + table = Texttable() + table.header(['wbits', 'groupsize', 'error']) + table.set_cols_dtype(['i', 'i', 'f']) + table.add_row([args.wbits, args.groupsize, error]) + + print('Optimizing {} {} ..'.format(name, layerid)) + for wbits, groupsize in conditions: + + if error < target: + # if error dropped 50%, skip + break + + gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False) + + scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name) + + table.add_row([wbits, groupsize, error]) + quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) + + print(table.draw()) + print('\n') + gptq.layer.to('cpu') + gptq.free() + + model.config.use_cache = use_cache + + return quantizers + + +# TODO: perform packing on GPU +def cai_llama_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + # print(f"model {model}") + # print(f"layers {layers}") + + layers = {n: layers[n] for n in quantizers} + # print(f"quantizers {quantizers}") + cai_gptq.make_cai_quant_linear(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [cai_gptq.CaiQuantLinear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return model + +def gptq_llama_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + # print(f"model {model}") + # print(f"layers {layers}") + + layers = {n: layers[n] for n in quantizers} + # print(f"quantizers {quantizers}") + quant.make_quant_linear(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [quant.QuantLinear]) + print('Packing ...') + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print('Done.') + return model + + +def cai_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): + from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils + config = LlamaConfig.from_pretrained(model) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = LlamaForCausalLM(config) + torch.set_default_dtype(torch.float) + if eval: + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + cai_gptq.make_cai_quant_linear(model, layers, wbits, groupsize) + + del layers + + print('Loading model ...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + + print('Done.') + + return model + + +def gptq_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): + from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils + config = LlamaConfig.from_pretrained(model) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = LlamaForCausalLM(config) + torch.set_default_dtype(torch.float) + if eval: + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + quant.make_quant_linear(model, layers, wbits, groupsize) + + del layers + + print('Loading model ...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + + print('Done.') + + return model + +all_perfs = [] +now_perf=[] + +def print_perf_stats(latency_set, config, warmup=3): + global now_perf + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + # if args.dtype == "float16": + # num_bytes = 2 + # elif args.dtype == "float32": + # num_bytes = 4 + # else: + # num_bytes = 1 + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1/avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1/avg * num_parameters * num_bytes * args.batch_size / 1e12)) + print("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) + print("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) + row = [args.batch_size, args.input_len, args.max_new_tokens, "{0:8.2f}".format(avg * 1000), + "{0:8.2f}".format(torch.cuda.memory_allocated() / 1e9), + "{0:8.2f}".format(torch.cuda.max_memory_allocated()/1e9)] + with open('./{}_profile.csv'.format(args.model_type), 'a', encoding='UTF8') as f: + # create the csv writer + writer = csv.writer(f) + + # write a row to the csv file + writer.writerow(row) + + now_perf.append("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + now_perf.append("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) + now_perf.append("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) + + all_perfs.append(now_perf) + now_perf = [] + +def benchmark(model): + + input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), + "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} + torch.cuda.synchronize() + iters = 10 if args.benchmark else 2 #warmup + print(f"model config {model.config}") + + times = [] + warmup=3 + prof_flag = 0 + generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) + torch.cuda.reset_peak_memory_stats() + for i in range(iters): + if i >= warmup: + prof_flag=1 + torch.cuda.synchronize() + start = time.time() + outputs = model.generate(**input_tokens, + **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + times.append(end - start) + print("outpus shape: ", outputs.shape) + print(args) + print("input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens) + # if args.local_rank == 0: + now_perf.extend(["input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens]) + print_perf_stats(map(lambda t: t / args.max_new_tokens, times), model.config) + +def test(model_1, model_2): + # input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), + # "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} + generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) + + + tokenizer = LlamaTokenizer.from_pretrained(args.model) + tokenizer.pad_token_id = tokenizer.unk_token_id + + text = "how is weather today? I want to know the weather of beijing. " + text = "how are you?" + + inputs = [text] + input_tokens = tokenizer.batch_encode_plus(inputs, padding = True, return_tensors="pt") + + input_len = 0 + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + # print(input_tokens[t].shape) + input_len = input_tokens[t].shape[1] + + outputs_1 = model_1.generate(**input_tokens, + **generate_kwargs) + + + outputs_2 = model_2.generate(**input_tokens, + **generate_kwargs) + + out_1 = tokenizer.batch_decode(outputs_1) + out_2 = tokenizer.batch_decode(outputs_2) + + ret = torch.allclose(outputs_1, outputs_2) + print("allclose is ", ret) + print("decode out:", out_1) + print("decode out:", out_2) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + parser.add_argument('model', type=str, help='llama model to load') + parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') + parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') + parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.') + parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') + parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.') + parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') + parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') + parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') + parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') + parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.') + parser.add_argument('--load', type=str, default='', help='Load quantized model.') + parser.add_argument('--benchmark', action='store_true', help='Number of tokens to use for benchmarking.') + parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') + parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') + parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') + parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential model.') + parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.') + parser.add_argument('--observe', + action='store_true', + help='Auto upgrade layer precision to higher precision, for example int2 to int4, groupsize 128 to 64. \ + When this feature enabled, `--save` or `--save_safetensors` would be disable.') + parser.add_argument('--quant-directory', type=str, default=None, help='Specify the directory for export quantization parameters to toml format. `None` means no export by default.') + parser.add_argument('--max_new_tokens', type=int, default=32, help='Max new tokens to generate.') + parser.add_argument('--batch_size', type=int, default=1, help='Batch size to generate.') + parser.add_argument('--input_len', type=int, default=128, help='Batch size to generate.') + parser.add_argument('--model_type', type=str, choices=['cai', 'gptq', 'torch'], default='torch', help='Batch size to generate.') + parser.add_argument('--debug', action='store_true', help='Whether to debug or not') + + args = parser.parse_args() + + model_packed = False + if type(args.load) is not str: + args.load = args.load.as_posix() + + if args.load: + if args.model_type == "gptq": + model = gptq_load_quant(args.model, args.load, args.wbits, args.groupsize) + elif args.model_type == "cai": + model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) + else: + model = get_llama(args.model) + model.half() + + if not args.load and args.wbits < 16 and not args.nearest and args.model_type in ['cai', 'gptq']: + dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen) + tick = time.time() + quantizers = llama_sequential(model, dataloader, DEV) + if args.model_type == "cai": + cai_llama_pack(model, quantizers, args.wbits, args.groupsize) + elif args.model_type == "gptq": + gptq_llama_pack(model, quantizers, args.wbits, args.groupsize) + model_packed = True + print(time.time() - tick) + + + if args.quant_directory is not None: + export_quant_table(quantizers, args.quant_directory) + + if not args.observe and args.save and args.model_type in ['cai', 'gptq']: + if not model_packed: + llama_pack(model, quantizers, args.wbits, args.groupsize) + model_packed = True + torch.save(model.state_dict(), args.save) + + if not args.observe and args.save_safetensors and args.model_type in ['cai', 'gptq']: + if not model_packed: + llama_pack(model, quantizers, args.wbits, args.groupsize) + from safetensors.torch import save_file as safe_save + state_dict = model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + safe_save(state_dict, args.save_safetensors) + + if args.benchmark: + model = model.to(DEV) + # print(f"model config {model.config.num_key_value_heads}") + + # if args.model_type == "cai": + # cai_inf_config = CaiInferenceConfig(fp16=True, + # device=torch.cuda.current_device(), + # gptq=True, + # gptq_group_size=128, + # gptq_quant_bits=4) + # model = convert_to_ds_model(model, cai_inf_config) + # model.cuda().to(torch.cuda.current_device()) + + + torch_model = get_llama(args.model) + torch_model.half() + torch_model = torch_model.to(DEV) + + test(torch_model, model) + + # # for batch in [1, 2, 4, 8, 16, 32]: + # for batch in [1]: + # args.batch_size = batch + # # for in_len in [128, 256, 512, 1024, 2048]: + # for in_len in [1024]: + # args.input_len = in_len + # benchmark(model) + # # for info in all_perfs: + # # print(info) + # # # all_perfs = [] \ No newline at end of file From 402452507e1ed02b6820dd88f600375d2e606412 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 11 Aug 2023 16:43:03 +0800 Subject: [PATCH 02/15] refactor code --- colossalai/gptq/__init__.py | 4 +- colossalai/gptq/cai_gptq/cai_quant_linear.py | 14 +- colossalai/gptq/cai_gptq/gptq_autotune.py | 167 ---- colossalai/gptq/cai_gptq/gptq_op.py | 82 +- colossalai/gptq/cai_gptq/gptq_triton.py | 317 +------- colossalai/gptq/config.py | 36 - colossalai/gptq/csrc/gptq_act_linear.cu | 387 --------- .../gptq/csrc/includes/conversion_utils.h | 641 --------------- .../gptq/csrc/includes/ds_kernel_utils.h | 52 -- .../csrc/includes/inference_cuda_layers.h | 32 - colossalai/gptq/csrc/pt_binding.cpp | 23 - colossalai/gptq/inference_builder.py | 761 ------------------ tests/test_gptq/linear_act_fusion_bench.py | 4 +- tests/test_gptq/quant_llama.py | 93 +-- tests/test_gptq/run_gptq.sh | 6 +- tests/test_gptq/test_linear_act_fusion.py | 94 +-- 16 files changed, 70 insertions(+), 2643 deletions(-) delete mode 100644 colossalai/gptq/cai_gptq/gptq_autotune.py delete mode 100644 colossalai/gptq/config.py delete mode 100644 colossalai/gptq/csrc/gptq_act_linear.cu delete mode 100644 colossalai/gptq/csrc/includes/conversion_utils.h delete mode 100644 colossalai/gptq/csrc/includes/ds_kernel_utils.h delete mode 100644 colossalai/gptq/csrc/includes/inference_cuda_layers.h delete mode 100644 colossalai/gptq/csrc/pt_binding.cpp delete mode 100644 colossalai/gptq/inference_builder.py diff --git a/colossalai/gptq/__init__.py b/colossalai/gptq/__init__.py index 55b7f3c85b2d..b28b04f64312 100644 --- a/colossalai/gptq/__init__.py +++ b/colossalai/gptq/__init__.py @@ -1,5 +1,3 @@ -from .config import CaiInferenceConfig -from .inference_builder import InferenceBuilder -from torch import nn + diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index f6ba8ab0394b..72a8e6d5607c 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -5,8 +5,6 @@ import torch.nn as nn from torch.cuda.amp import custom_bwd, custom_fwd from .gptq_op import CaiGPTQLinearOp -from ..config import CaiInferenceConfig -from .gptq_triton import gptq_linear_llama import triton class CaiQuantLinear(nn.Module): @@ -34,9 +32,7 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): else: self.bias = None - cai_inf_config = CaiInferenceConfig(fp16=True, - gptq_group_size=self.groupsize) - self.gptq_linear = CaiGPTQLinearOp(cai_inf_config) + self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) self.printed = False self.reorder_zeros = False def pack(self, linear, scales, zeros, g_idx=None): @@ -113,19 +109,11 @@ def pack(self, linear, scales, zeros, g_idx=None): def forward(self, x): - # if self.reorder_zeros == False: - # for i in range(self.g_idx.shape[0]): - # idx = self.g_idx[i] - # self.order_qzeros[i,:] = self.qzeros[idx,:] - # gptq_out = gptq_linear_llama(x, self.qweight, self.scales, self.qzeros, self.g_idx, - # self.bits, self.maxq) - cai_out = self.gptq_linear(x, self.qweight, self.scales, self.qzeros, bias = self.bias) - print("shape is ", cai_out.shape) return cai_out def make_cai_quant_linear(module, names, bits, groupsize, name=''): diff --git a/colossalai/gptq/cai_gptq/gptq_autotune.py b/colossalai/gptq/cai_gptq/gptq_autotune.py deleted file mode 100644 index 18e9969d8d00..000000000000 --- a/colossalai/gptq/cai_gptq/gptq_autotune.py +++ /dev/null @@ -1,167 +0,0 @@ -import math -import time -import torch - -class AutoTune: - def __init__(self, tune_func, warmup=10, bech_run=20): - - self.func = tune_func - self.warmup_num = warmup - self.bech_run = bech_run - self.config_caches = {} - - def prune_configs(self, tune_config): - - norm_configs = [] - linear_configs = [] - - if tune_config['qkv_fused']: - max_in = 2**int(math.log2(tune_config['in_dim'])) - in_dim = tune_config['in_dim'] - else: - max_in = 2**int(math.log2(tune_config['in_dim'])) - in_dim = tune_config['in_dim'] - - max_out = 2**int(math.log2(tune_config['out_dim'])) - if max_out > 1024: - max_out = 1024 - m = 2 - n = 64 - x = 64 - y = 1 - - # while n <= max_out: - # ret_config = { - # "linear_x": n, - # "linear_y": in_dim, - # } - # linear_configs.append(ret_config) - # n = n * 2 - while n <= max_out: - m = 64 - while m < in_dim: - ret_config = { - "linear_x": n, - "linear_y": m, - } - linear_configs.append(ret_config) - m = m * 2 - ret_config = { - "linear_x": n, - "linear_y": in_dim, - } - linear_configs.append(ret_config) - n = n * 2 - - # if tune_config['act_type'] == 0: - # while n <= max_out: - # m = 64 - # while m <= max_in: - # ret_config = { - # "norm_x": 0, - # "norm_y": 0, - # "linear_x": n, - # "linear_y": m, - # } - # linear_configs.append(ret_config) - # m = m * 2 - # n = n * 2 - # elif tune_config['act_type'] > 0: - # while n <= max_out: - # ret_config = { - # "norm_x": 0, - # "norm_y": 0, - # "linear_x": n, - # "linear_y": in_dim, - # } - # linear_configs.append(ret_config) - # n = n * 2 - return linear_configs - - def warmup(self, tune_config, *args, **kwargs): - # if tune_config['qkv_fused']: - # output = torch.zeros(3, tune_config['input_len'], tune_config['out_dim'], - # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() - # else: - # output = torch.zeros(tune_config['input_len'], tune_config['out_dim'], - # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() - - for i in range(0, self.warmup_num): - # out = self.func(*args[:6], output, *args[7:],**kwargs) - out = self.func(*args, **kwargs) - - def benchmark(self, tune_config, *args, **kwargs): - - self.warmup(tune_config, *args, **kwargs) - linear_configs = self.prune_configs(tune_config) - # print(ret_configs) - times = {} - best_norm_x = 512 - best_norm_y = 1 - best_linear_x = 512 - best_linear_y = 256 - # if best_linear_x > tune_config['out_dim']: - # best_linear_x = 2**int(math.log2(tune_config['out_dim'])) - # if best_linear_y > tune_config['in_dim'] and tune_config['qkv_fused'] == False: - # best_linear_y = 2**int(math.log2(tune_config['in_dim'])) - # if best_linear_y > tune_config['in_dim'] and tune_config['qkv_fused']: - # best_linear_y = 2**int(math.log2(tune_config['in_dim'] //3)) - if best_norm_x > tune_config['input_dim']: - best_norm_x = 2**int(math.log2(tune_config['input_dim'])) - - if tune_config['wdtype'] == torch.int8: - nweights = 2 - elif tune_config['wdtype'] == torch.int32: - nweights = 8 - elif tune_config['wdtype'] == torch.int64: - nweights = 16 - times = {} - - # if tune_config['qkv_fused']: - # output = torch.zeros(3, tune_config['input_len'], tune_config['out_dim'], - # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() - # else: - # output = torch.zeros(tune_config['input_len'], tune_config['out_dim'], - # dtype = torch.float16, device=torch.cuda.current_device()).contiguous() - - for config in linear_configs: - linear_x = config['linear_x'] - linear_y = config['linear_y'] - print(config) - start = time.time() - for run in range(0, self.bech_run): - # out = self.func(*args[:6], output, *args[7:-2], linear_x, linear_y) - out = self.func(*args[:-2], linear_x, linear_y) - - torch.cuda.synchronize() - end = time.time() - times[' '.join(map(str,config.values()))] = end - start - # print(f"{config}: {end-start:.6f}") - sorted_dict = sorted(times.items(), key=lambda x:x[1]) - values = sorted_dict[0][0].split() - # print(sorted_dict) - best_linear_x = int(values[0]) - best_linear_y = int(values[1]) - - times = {} - - key = ' '.join(map(str,tune_config.values())) - - ret_config = { - "linear_x": best_linear_x, - "linear_y": best_linear_y, - } - self.config_caches[key] = ret_config - # print("best config:", tune_config, ret_config) - def get_best_config(self, tune_config, *args, **kwargs): - key = ' '.join(map(str,tune_config.values())) - - if key in self.config_caches: - return self.config_caches[key] - else: - # print(tune_config) - self.benchmark(tune_config, *args, **kwargs) - return self.config_caches[key] - - - diff --git a/colossalai/gptq/cai_gptq/gptq_op.py b/colossalai/gptq/cai_gptq/gptq_op.py index 13af9953395c..7ada87055a97 100644 --- a/colossalai/gptq/cai_gptq/gptq_op.py +++ b/colossalai/gptq/cai_gptq/gptq_op.py @@ -1,32 +1,15 @@ - -from ..config import CaiInferenceConfig -from ..inference_builder import inference_cuda_module -from .gptq_autotune import AutoTune from .gptq_triton import gptq_fused_linear_triton import torch -class BaseOp(torch.nn.Module): - inference_cuda_module = inference_cuda_module - def __init__(self, config: CaiInferenceConfig): - super(BaseOp, self).__init__() - self.config = config - if BaseOp.inference_cuda_module is None: - BaseOp.inference_cuda_module = inference_cuda_module - - -class CaiGPTQLinearOp(BaseOp): - autotune = None - def __init__(self, config: CaiInferenceConfig): - super(CaiGPTQLinearOp, self).__init__(config) - self.linear_func = self.inference_cuda_module.gptq_act_linear_fp16 +class CaiGPTQLinearOp(torch.nn.Module): - self.group_size = config.gptq_group_size - self.bits = config.gptq_quant_bits + def __init__(self, gptq_group_size, gptq_quant_bits): + super(CaiGPTQLinearOp, self).__init__() + self.group_size = gptq_group_size + self.bits = gptq_quant_bits self.maxq = 2**self.bits - 1 self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device()) - if CaiGPTQLinearOp.autotune == None: - CaiGPTQLinearOp.autotune = AutoTune(self.linear_func) def forward(self, input: torch.Tensor, @@ -37,6 +20,7 @@ def forward(self, bias: torch.Tensor = None, residual: torch.Tensor=None, qkv_fused = False): + add_bias = True if bias is None: bias = self.empty_tensor @@ -48,54 +32,12 @@ def forward(self, add_residual = False x = input.view(-1, input.shape[-1]) - if x.shape[0] > 1: - out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual, - self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual, act_type=act_type) - if qkv_fused: - out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) - else: - out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) + + out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual, + self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual, act_type=act_type) + if qkv_fused: + out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) else: - print("inut shape, ", input.shape) - config = { - "input_dim": input.shape[-1], - "input_len": input.shape[0] * input.shape[1] , - "add_bias": add_bias, - "add_residual": add_residual, - "qkv_fused": qkv_fused, - "act_type": act_type, - "out_dim": weight.shape[-1], - "in_dim": input.shape[-1], - "wdtype": weight.dtype - } + out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) - best_config = CaiGPTQLinearOp.autotune.get_best_config(config, - input, - weight, - weight_scales, - weight_zeros, - bias, - residual, - self.group_size, - act_type, - add_bias, - add_residual, - qkv_fused, - 128, - 128) - block_size_x = best_config['linear_x'] - block_size_y = best_config['linear_y'] - out = self.linear_func(input, - weight, - weight_scales, - weight_zeros, - bias, - residual, - self.group_size, - act_type, - add_bias, - add_residual, - qkv_fused, - block_size_x, - block_size_y) return out \ No newline at end of file diff --git a/colossalai/gptq/cai_gptq/gptq_triton.py b/colossalai/gptq/cai_gptq/gptq_triton.py index 44ec38fd80b5..eb77987ef625 100644 --- a/colossalai/gptq/cai_gptq/gptq_triton.py +++ b/colossalai/gptq/cai_gptq/gptq_triton.py @@ -99,173 +99,6 @@ def smelu(x): def silu(x): return x*tl.sigmoid(x) -@custom_autotune.autotune( - configs=[ - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), - ], - key=['M', 'N', 'K'], - nearest_power_of_two=True, - prune_configs_by={ - 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, - 'perf_model': None, - 'top_k': None, - }, -) -@triton.jit -def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, - M, N, K, bits, maxq, gptq_group_size, - stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, - QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//16, N) int64 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - """ - infearure_per_bits = 64 // bits - - pid = tl.program_id(axis=0) - NK = K - - # if QKV_FUSED: - # NK = K//3 - # else: - # NK = K - # NK = K - - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(NK, BLOCK_SIZE_K) - qkv_offset = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - # offs_bk = offs_k + qkv_offset * NK - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + qkv_offset * N * NK //infearure_per_bits + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - # g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + qkv_offset * NK * N //gptq_group_size + offs_bn[None, :] - zeros_ptrs = zeros_ptr + qkv_offset * NK * N //gptq_group_size//infearure_per_bits + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - g_idx_base = tl.arange(0, BLOCK_SIZE_K) - g_idx_base = g_idx_base // gptq_group_size - g_idx = g_idx_base - # tl.device_print("gidx, ", g_idx) - - currend_group_end = 0 - for k in range(0, num_pid_k): - # g_idx = tl.load(g_ptrs) - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros).to(tl.float16) * scales # Scale and shift - accumulator += tl.dot(a, b) - - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_idx = g_idx_base + ((k + 1) * BLOCK_SIZE_K) // gptq_group_size - # if (k + 2) * BLOCK_SIZE_K > currend_group_end: - - c_ptrs = c_ptr + qkv_offset * M * N + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - - if ADD_BIAS: - bias_mask = (offs_bn < N) - offs_bn += qkv_offset * N - bias_ptrs = bias_ptr + stride_cn * offs_bn - bias = tl.load(bias_ptrs, mask=bias_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - accumulator += bias[None, :] - - - if ACT_TYPE == 1: - accumulator=relu(accumulator) - elif ACT_TYPE == 2: - accumulator=gelu(accumulator) - elif ACT_TYPE == 3: - accumulator=silu(accumulator) - - - if ADD_RESIDUAL: - residual_ptrs = residual_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - res = tl.load(residual_ptrs, mask=c_mask, other=0.) - accumulator += res - - tl.store(c_ptrs, accumulator, mask=c_mask) - - @custom_autotune.autotune( configs=[ @@ -327,7 +160,7 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ }, ) @triton.jit -def cai_gptq_v2_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, +def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ptr, residual_ptr, M, N, K, bits, maxq, gptq_group_size, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, QKV_FUSED: tl.constexpr, ADD_BIAS: tl.constexpr, ADD_RESIDUAL:tl.constexpr, ACT_TYPE:tl.constexpr, @@ -616,7 +449,7 @@ def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) if idx is None: - cai_gptq_v2_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, bias, residual, + cai_gptq_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, bias, residual, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, gptq_group_size, input.stride(0), input.stride(1), qweight.stride(0), @@ -633,149 +466,3 @@ def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, return output.view(3, input.shape[0], qweight.shape[1]) else: return output - - - -# code based https://github.com/fpgaminer/GPTQ-triton -@custom_autotune.autotune( - configs=[ - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), - ], - key=['M', 'N', 'K'], - nearest_power_of_two=True, - prune_configs_by={ - 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, - 'perf_model': None, - 'top_k': None, - }, -) -@triton.jit -def gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def gptq_linear_llama(x, qweight, scales, qzeros, g_idx, - bits, maxq): - - out_shape = x.shape[:-1] + (qweight.shape[-1], ) - input = x.reshape(-1, x.shape[-1]) - # print("input shape:", input.shape)/ - with torch.cuda.device(input.device): - output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) - gptq_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], - qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), - qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) - # output = output.reshape(out_shape) - - return output.reshape(out_shape) \ No newline at end of file diff --git a/colossalai/gptq/config.py b/colossalai/gptq/config.py deleted file mode 100644 index fc2dd4cb9661..000000000000 --- a/colossalai/gptq/config.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -import json -import torch -from enum import IntEnum - -DEFAULT_INTERMEDIATE_SIZE = -1 -class ActivationFuncType(IntEnum): - UNKNOWN = 0 - ReLU = 1 - GELU = 2 - SiLU = 3 - GATED_GELU = 4 - GATED_SILU = 5 - - -class CaiInferenceConfig(): - - - def __init__(self, - fp16=True, - gptq=False, - gptq_group_size=128, - gptq_quant_bits=4, - gptq_weight_dtype=torch.int64 - ): - self.fp16 = fp16 - self.gptq = gptq - self.gptq_group_size = gptq_group_size - self.gptq_quant_bits = gptq_quant_bits - self.gptq_weight_dtype = gptq_weight_dtype - - diff --git a/colossalai/gptq/csrc/gptq_act_linear.cu b/colossalai/gptq/csrc/gptq_act_linear.cu deleted file mode 100644 index b622233df89e..000000000000 --- a/colossalai/gptq/csrc/gptq_act_linear.cu +++ /dev/null @@ -1,387 +0,0 @@ -#include "conversion_utils.h" -#include "inference_cuda_layers.h" -#include -#include -#include -#include -#include -#include -#include -#include -#define SHARE_MEM_SIZE (48 * 1024) -inline __device__ float relu(const float x) { return x < 0 ? 0 : x; } -inline __device__ float gelu(const float x) -{ - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); -} -inline __device__ float silu(const float x) -{ - return x / (1 + expf(-x)); -} -/*** -input: the input size is [b, l, m] -weight: the weight size is [m/size(TW)*2, n] -weight_scales: the weight scales size is [m/group_size, n] -weight_zeros: the weight scales size is [m/group_size, n/size(TW)*2] -bias: linear bias [n] -input_dim0: m -input_dim1: b * l -weight_dim0: n -weight_dim1: m -block_size_m: m for one gpu thread block -block_size_n: n for one gpu thread block -the computation block is [block_size_m, block_size_n] for one gpu thread block -group_size: the group size for gptq quant -add_bias: the linear has bias or not -***/ -template -__global__ void gptq_gemm(T* input, - TW* weight, - T* weight_scales, - TW* weight_zeros, - T* bias, - T* residual, - T* output, - uint64_t input_dim0, - uint64_t input_dim1, - uint64_t weight_dim0, - uint64_t weight_dim1, - uint64_t group_size, - int32_t act_type, - bool add_bias, - bool add_residual, - bool qkv_fused, - uint64_t block_size_m, - uint64_t block_size_n) -{ - const uint32_t n_weights = sizeof(TW) * 2; // number of compressed weights in a TW. - - uint64_t block_offset = blockIdx.x; - uint64_t local_tid = threadIdx.x; - uint64_t block_tnum = blockDim.x; - uint64_t block_m_start = blockIdx.y * block_size_m; - uint64_t block_m_end = (blockIdx.y + 1) * block_size_m; - block_m_end = std::min(block_m_end, weight_dim1 * n_weights); - - uint64_t group_step = 32; - group_step = group_step - group_step % n_weights; - group_step = std::min(group_step, group_size); - - uint64_t group_block = group_step / n_weights; - uint64_t table_iter = (group_step / 2 * 256) / block_tnum; - uint64_t col_offset = block_size_n * block_offset; - - __shared__ float table[16][256]; // look-up table, for 32 inputs - __shared__ float i2sum[16]; - return; - - uint64_t qkv_offset = 0; - uint64_t qkv_out_base_offset = col_offset; - uint64_t bias_base_offset = col_offset; - uint64_t split_m_size = weight_dim1 * n_weights; - if (qkv_fused) - { - split_m_size = weight_dim1 * n_weights / 3; - qkv_offset = block_m_start / split_m_size; - qkv_out_base_offset = qkv_offset * input_dim1 * weight_dim0 + col_offset; - bias_base_offset = qkv_offset * weight_dim0 + col_offset; - } - - float tmp_w_res = conversion::to(0.0); - float tmp_z_res = conversion::to(0.0); - float tmp_final_res = conversion::to(0.0); - float tmp_weight_scales; - float tmp_weight_zero; - - uint64_t current_group_size = group_size; - uint64_t scale_dim1_ind = block_m_start / group_size; - - for (uint64_t i = block_m_start; i < block_m_end; i += current_group_size) - { - if (i + current_group_size > block_m_end) - current_group_size = block_m_end - i; - - // // index of weight scale - // uint64_t dind = (i / group_size) * weight_dim0 + col_offset + local_tid; - // int32_t i_zero = - // ((weight_zeros[dind / n_weights] >> (((col_offset + local_tid) & 0xf) * 4)) & 0xf) + 1; - - // tmp_weight_scales = conversion::to(weight_scales[dind]); - // tmp_weight_zero = conversion::to(i_zero); - // if (i + current_group_size > block_m_end) - // current_group_size = block_m_end - i; - - // index of weight scale - uint64_t scale_index = - scale_dim1_ind * weight_dim0 + col_offset + local_tid; - // 4 is 4bits weight. 0xf is mask for 4 bits weight. 1 is for gptq algorithm. - int32_t i_zero = ((weight_zeros[scale_index / n_weights] >> - ((scale_index & 0xf) * 4)) & - 0xf) + - 1; - - tmp_weight_scales = conversion::to(weight_scales[scale_index]); - tmp_weight_zero = conversion::to(i_zero); - scale_dim1_ind += 1; - for (uint64_t j = 0; j < current_group_size; j += group_step) - { - -// compute lookup table -#pragma unroll - for (uint64_t k = 0; k < table_iter; k++) - { - - // uint64_t table_id = k * block_tnum + local_tid; - // uint64_t dind = table_id & 0xff; - // uint64_t tid = table_id >> 8; - // uint64_t input_offset = i + j + tid * 2; - - uint64_t table_id = k * block_tnum + local_tid; - uint64_t weight_id = table_id & 0xff; - uint64_t input_id = table_id >> 8; - // 2 is number of inputs for one table elements. - uint64_t input_offset = (i + j + input_id * 2) % split_m_size; - - // float i1, i2; - - // float i1 = relu(conversion::to(input[input_offset])); - // float i2 = relu(conversion::to(input[input_offset + 1])); - float i1 = (conversion::to(input[input_offset])); - float i2 = (conversion::to(input[input_offset + 1])); - - i2sum[input_id] = i1 + i2; - - int32_t iw1 = weight_id & 0xf; - int32_t iw2 = weight_id >> 4; - - float w1 = conversion::to(iw1); - float w2 = conversion::to(iw2); - - table[input_id][weight_id] = w1 * i1 + w2 * i2; - } - __syncthreads(); -#pragma unroll - for (uint64_t k = 0; k < group_block; k++) - { - - uint64_t base_weight_offset = ((i + j) / n_weights + k) * weight_dim0; - uint64_t dind = base_weight_offset + col_offset + local_tid; - - TW w = weight[dind]; - -#pragma unroll - for (uint64_t z = 0; z < n_weights / 2; z++) - { - uint32_t k1 = k * n_weights / 2 + z; - TW w1 = (w >> (z * 8)) & 0xff; - - tmp_w_res += table[k1][w1]; - tmp_z_res += i2sum[k1]; - } - } - } - - tmp_final_res += - (tmp_w_res - tmp_z_res * tmp_weight_zero) * tmp_weight_scales; - tmp_w_res = conversion::to(0.0); - tmp_z_res = conversion::to(0.0); - } - - - if(col_offset + local_tid < input_dim0 * input_dim1) - { - uint64_t bias_offset = bias_base_offset + local_tid; - float bias_v = 0; - float residual_v = 0; - if (add_bias && blockIdx.y == 0) - { - bias_v = conversion::to(bias[bias_offset]); - tmp_final_res += bias_v; - } - uint64_t dind = qkv_out_base_offset + local_tid; - if(act_type == 1) - { - tmp_final_res = relu(tmp_final_res); - } - else if(act_type == 2) - { - tmp_final_res = gelu(tmp_final_res); - } - else if(act_type == 3) - { - tmp_final_res = silu(tmp_final_res); - } - - if (add_residual && blockIdx.y == 0){ - residual_v = conversion::to(residual[dind]); - tmp_final_res += residual_v; - } - T tmp_res = conversion::to(tmp_final_res); - // float *o = (float*)output; - atomicAdd(&output[dind], tmp_res); - // atomicAdd(&output[dind], tmp_final_res); - - } -} - - - -template -at::Tensor gptq_act_linear_layer(at::Tensor& input, - at::Tensor& weight, - at::Tensor& weight_scales, - at::Tensor& weight_zeros, - at::Tensor& bias, - at::Tensor& residual, - int64_t group_size, - int32_t act_type, - int32_t add_bias, - int32_t add_residual, - int32_t qkv_fused, - uint64_t block_size_x, - uint64_t block_size_y) -{ - - uint64_t input_dim0 = input.sizes()[2]; - uint64_t input_dim1 = input.sizes()[0] * input.sizes()[1]; - - uint64_t weight_dim0 = weight.sizes()[1]; - uint64_t weight_dim1 = weight.sizes()[0]; - - - auto options = - torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA); - - std::vector out_shape; - if (qkv_fused) - out_shape.push_back(3); - out_shape.push_back(input.sizes()[0]); - out_shape.push_back(input.sizes()[1]); - out_shape.push_back(weight.sizes()[1]); - - at::Tensor output = at::zeros(out_shape, options); - - T* input_ptr = (T*)input.data_ptr(); - TW* weight_ptr = (TW*)weight.data_ptr(); - T* weight_scales_ptr = (T*)weight_scales.data_ptr(); - TW* weight_zeros_ptr = (TW*)weight_zeros.data_ptr(); - T* bias_ptr = (T*)bias.data_ptr(); - T* output_ptr = (T*)output.data_ptr(); - T* residual_ptr = (T*)residual.data_ptr(); - // at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(); - auto stream = at::cuda::getCurrentCUDAStream().stream(); -// #define BENCHMARK -#ifdef BENCHMARK - uint32_t block_xs[] = {128, 256, 512}; - uint32_t block_ys[] = {128, 256, 512, 1024}; - - for (uint32_t i = 0; i < 3; i++) - { - for (uint32_t j = 0; j < 4; j++) - { - - block_size_x = block_xs[i]; - block_size_y = block_ys[j]; - uint32_t warm_up = 2; - uint32_t bench = 5; - auto start = std::chrono::high_resolution_clock::now(); - auto end = std::chrono::high_resolution_clock::now(); - for (uint32_t k = 0; k < warm_up + bench; k++) - { - - if (k == warm_up) - start = std::chrono::high_resolution_clock::now(); - -#endif - uint64_t block_size_m = block_size_y; - uint64_t block_size_n = block_size_x; - - uint64_t block_tnum = block_size_x; - - // printf("block size m %d %d\n", weight_dim1, weight_dim0); - // printf("block size m %d %d\n", input_dim1, input_dim0); - - if (input_dim1 == 1) - { - - dim3 block_dim(block_tnum, 1, 1); - dim3 grid_dim(weight_dim0 / block_tnum, - (weight_dim1 * sizeof(TW) * 2 + block_size_y - 1) / block_size_y, - 1); - // printf("block size m %d %d\n", weight_dim1, weight_dim0); - // printf("block size m %d %d\n", input_dim1, input_dim0); - // printf("block size m %d %d\n", block_tnum, weight_dim0 / block_tnum); - // printf("block size m %d %d\n", weight_dim1 * sizeof(TW) * 2 / block_size_y, input_dim0); - gptq_gemm - <<>>(input_ptr, - weight_ptr, - weight_scales_ptr, - weight_zeros_ptr, - bias_ptr, - residual_ptr, - output_ptr, - input_dim0, - input_dim1, - weight_dim0, - weight_dim1, - group_size, - act_type, - add_bias, - add_residual, - qkv_fused, - block_size_m, - block_size_n); - } - else - { - printf("cuda kernel not support batch * seq_len > 1\n"); - } - -#ifdef BENCHMARK - } - end = std::chrono::high_resolution_clock::now(); - double sec = - (double)(std::chrono::duration_cast( - end - start) - .count()) / - 1e9 / 5; - - printf("block x: %d, block y: %d, %.8f\n", - block_size_x, - block_size_y, - sec); - } - } -#endif - // float *o = (float*)output_ptr; - // for(int i = 0; i < 64; i ++){ - // printf("%f ", o[i]); - // } - // printf("\n"); - return output; -} - -#define INSTANTIATE_ACT_GPTQ_LINEAR(T, TW) \ - template at::Tensor gptq_act_linear_layer( \ - at::Tensor & input, \ - at::Tensor & weight, \ - at::Tensor & weight_scales, \ - at::Tensor & weight_zeros, \ - at::Tensor & bias, \ - at::Tensor & residual, \ - int64_t group_size, \ - int32_t act_type, \ - int32_t add_bias, \ - int32_t add_residual, \ - int32_t qkv_fused, \ - uint64_t block_size_x, \ - uint64_t block_size_y); - -// INSTANTIATE_ACT_GPTQ_LINEAR(float, uint64_t) -INSTANTIATE_ACT_GPTQ_LINEAR(__half, uint64_t) -// INSTANTIATE_ACT_GPTQ_LINEAR(float, uint32_t) -INSTANTIATE_ACT_GPTQ_LINEAR(__half, uint32_t) -// INSTANTIATE_ACT_GPTQ_LINEAR(float, uint8_t) -INSTANTIATE_ACT_GPTQ_LINEAR(__half, uint8_t) diff --git a/colossalai/gptq/csrc/includes/conversion_utils.h b/colossalai/gptq/csrc/includes/conversion_utils.h deleted file mode 100644 index 3d31e37de364..000000000000 --- a/colossalai/gptq/csrc/includes/conversion_utils.h +++ /dev/null @@ -1,641 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#pragma once - -#include "ds_kernel_utils.h" - -#include -#include - -#ifdef BF16_AVAILABLE -#include -#endif - -namespace conversion { - -// Basic primitive for constructing conversions -template -DS_D_INLINE TO to(FROM val) -{ - return to(val); -} - -// Specializations - -/********************* Identity Conversions *********************/ -/* -Identity conversions are useful in templated functions where we might have -a fixed destination type. For example, I might have a kernel that accepts -__half, __nv_bfloat16, and float but always want to do the core computation -at floating point: - -T mem_value = input[idx]; -float compute_value = conversion::to(mem_value); - -In practice, we should be able to elide the second template parameter: -float compute_val = conversion::to(mem_value); - -In this case, we need an implementation to handle the T = float case - -NOTE: The type inferencing system appears to be unable to handle inferring the first -template parameter, even in the trivial case. -*/ - -// Floating point types -template <> -DS_D_INLINE double to(double val) -{ - return val; -} -template <> -DS_D_INLINE float to(float val) -{ - return val; -} -template <> -DS_D_INLINE __half to(__half val) -{ - return val; -} -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) -{ - return val; -} -#endif - -// Integer types -template <> -DS_D_INLINE int8_t to(int8_t val) -{ - return val; -} -template <> -DS_D_INLINE uint8_t to(uint8_t val) -{ - return val; -} -template <> -DS_D_INLINE int16_t to(int16_t val) -{ - return val; -} -template <> -DS_D_INLINE uint16_t to(uint16_t val) -{ - return val; -} -template <> -DS_D_INLINE int32_t to(int32_t val) -{ - return val; -} -template <> -DS_D_INLINE uint32_t to(uint32_t val) -{ - return val; -} -template <> -DS_D_INLINE int64_t to(int64_t val) -{ - return val; -} -template <> -DS_D_INLINE uint64_t to(uint64_t val) -{ - return val; -} - -// TODO: evaluate if we want bools - -/********************* To Double Conversions *********************/ - -// * to double variants - -// Would normally like to not use C cast, but this is an important enough conversion -// to keep -template <> -DS_D_INLINE double to(float val) -{ -#ifdef PTX_AVAILABLE - double ret_val; - asm("ctv.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val)); - return ret_val; -#else - return double(val); -#endif -} -// Note: there is a CVT instruction for __half -> double, but there's no inline interface -// for passing a single half value -template <> -DS_D_INLINE double to(__half val) -{ - return to(__half2float(val)); -} -template <> -DS_D_INLINE double to(int64_t val) -{ - return __ll2double_rn(val); -} -template <> -DS_D_INLINE double to(int32_t val) -{ - return __int2double_rn(val); -} -template <> -DS_D_INLINE double to(int16_t val) -{ - return __int2double_rn(val); -} -template <> -DS_D_INLINE double to(int8_t val) -{ - return __int2double_rn(val); -} -template <> -DS_D_INLINE double to(uint64_t val) -{ - return __ull2double_rn(val); -} -template <> -DS_D_INLINE double to(uint32_t val) -{ - return __uint2double_rn(val); -} -template <> -DS_D_INLINE double to(uint16_t val) -{ - return __uint2double_rn(val); -} -template <> -DS_D_INLINE double to(uint8_t val) -{ - return __uint2double_rn(val); -} - -// Same applies here -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE double to(__nv_bfloat16 val) -{ - return to(__bfloat162float(val)); -} -#endif - -/********************* To Float Conversions *********************/ - -template <> -DS_D_INLINE float to(double val) -{ - return __double2float_rn(val); -} -template <> -DS_D_INLINE float to(__half val) -{ - return __half2float(val); -} -template <> -DS_D_INLINE float to(int64_t val) -{ - return __ll2float_rn(val); -} -template <> -DS_D_INLINE float to(int32_t val) -{ - return __int2float_rn(val); -} -template <> -DS_D_INLINE float to(int16_t val) -{ - return __int2float_rn(val); -} -template <> -DS_D_INLINE float to(int8_t val) -{ - return __int2float_rn(val); -} -template <> -DS_D_INLINE float to(uint64_t val) -{ - return __ull2float_rn(val); -} -template <> -DS_D_INLINE float to(uint32_t val) -{ - return __uint2float_rn(val); -} -template <> -DS_D_INLINE float to(uint16_t val) -{ - return __uint2float_rn(val); -} -template <> -DS_D_INLINE float to(uint8_t val) -{ - return __uint2float_rn(val); -} - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE float to(__nv_bfloat16 val) -{ - return __bfloat162float(val); -} -#endif - -/********************* To Float2 Conversions *********************/ -template <> -DS_D_INLINE float2 to(__half2 val) -{ - return __half22float2(val); -} - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE float2 to(__nv_bfloat162 val) -{ - return __bfloat1622float2(val); -} -#endif - -/********************* To Half Conversions *********************/ -template <> -DS_D_INLINE __half to(double val) -{ -#ifdef __HIP_PLATFORM_HCC__ - float val_f = __double2float_rn(val); - return __float2half(val_f); -#else - return __double2half(val); -#endif -} -template <> -DS_D_INLINE __half to(float val) -{ - return __float2half(val); -} -template <> -DS_D_INLINE __half to(int64_t val) -{ - return __ll2half_rn(val); -} -template <> -DS_D_INLINE __half to(int32_t val) -{ - return __int2half_rn(val); -} -template <> -DS_D_INLINE __half to(int16_t val) -{ - return __short2half_rn(val); -} -template <> -DS_D_INLINE __half to(int8_t val) -{ - return __int2half_rn(val); -} -template <> -DS_D_INLINE __half to(uint64_t val) -{ - return __ull2half_rn(val); -} -template <> -DS_D_INLINE __half to(uint32_t val) -{ - return __uint2half_rn(val); -} -template <> -DS_D_INLINE __half to(uint16_t val) -{ - return __ushort2half_rn(val); -} -template <> -DS_D_INLINE __half to(uint8_t val) -{ - return __uint2half_rn(val); -} - -#ifdef BF16_AVAILABLE -// No direct conversion -template <> -DS_D_INLINE __half to(__nv_bfloat16 val) -{ - return to<__half>(to(val)); -} -#endif - -/********************* To Half2 Conversions *********************/ -template <> -DS_D_INLINE __half2 to(float2 val) -{ - return __float22half2_rn(val); -} -template <> -DS_D_INLINE __half2 to(float val) -{ - return __float2half2_rn(val); -} - -#ifdef BF16_AVAILABLE -// No direct conversion -template <> -DS_D_INLINE __half2 to(__nv_bfloat162 val) -{ - return to<__half2>(to(val)); -} -#endif - -/********************* To BF16 Conversions *********************/ -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE __nv_bfloat16 to(double val) -{ - return __double2bfloat16(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(float val) -{ - return __float2bfloat16(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(int64_t val) -{ - return __ll2bfloat16_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(int32_t val) -{ - return __int2bfloat16_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(int16_t val) -{ - return __short2bfloat16_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(int8_t val) -{ - return __int2bfloat16_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(uint64_t val) -{ - return __ull2bfloat16_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(uint32_t val) -{ - return __uint2bfloat16_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(uint16_t val) -{ - return __ushort2bfloat16_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat16 to(uint8_t val) -{ - return __uint2bfloat16_rn(val); -} -#endif - -/********************* To BF162 Conversions *********************/ -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE __nv_bfloat162 to(float2 val) -{ - return __float22bfloat162_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat162 to(float val) -{ - return __float2bfloat162_rn(val); -} -template <> -DS_D_INLINE __nv_bfloat162 to(__half2 val) -{ - return to<__nv_bfloat162>(to(val)); -} -#endif - -/********************* To INT64_T Conversions *********************/ -template <> -DS_D_INLINE int64_t to(double val) -{ - return __double2ll_rn(val); -} -template <> -DS_D_INLINE int64_t to(float val) -{ - return __float2ll_rn(val); -} -template <> -DS_D_INLINE int64_t to(__half val) -{ - return __half2ll_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE int64_t to(__nv_bfloat16 val) -{ - return __bfloat162ll_rn(val); -} -#endif - -/********************* To INT32_T Conversions *********************/ -template <> -DS_D_INLINE int32_t to(double val) -{ - return __double2int_rn(val); -} -template <> -DS_D_INLINE int32_t to(float val) -{ - return __float2int_rn(val); -} -template <> -DS_D_INLINE int32_t to(__half val) -{ - return __half2int_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE int32_t to(__nv_bfloat16 val) -{ - return __bfloat162int_rn(val); -} -#endif - -/********************* To INT16_T Conversions *********************/ -template <> -DS_D_INLINE int16_t to(double val) -{ - return __double2int_rn(val); -} -template <> -DS_D_INLINE int16_t to(float val) -{ - return __float2int_rn(val); -} -template <> -DS_D_INLINE int16_t to(__half val) -{ - return __half2int_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE int16_t to(__nv_bfloat16 val) -{ - return __bfloat162int_rn(val); -} -#endif - -/********************* To INT8_T Conversions *********************/ -template <> -DS_D_INLINE int8_t to(double val) -{ - return __double2int_rn(val); -} -template <> -DS_D_INLINE int8_t to(float val) -{ - return __float2int_rn(val); -} -template <> -DS_D_INLINE int8_t to(__half val) -{ - return __half2int_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE int8_t to(__nv_bfloat16 val) -{ - return __bfloat162int_rn(val); -} -#endif - -/********************* To UINT64_T Conversions *********************/ -template <> -DS_D_INLINE uint64_t to(double val) -{ - return __double2ull_rn(val); -} -template <> -DS_D_INLINE uint64_t to(float val) -{ - return __float2ull_rn(val); -} -template <> -DS_D_INLINE uint64_t to(__half val) -{ - return __half2ull_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE uint64_t to(__nv_bfloat16 val) -{ - return __bfloat162ull_rn(val); -} -#endif - -/********************* To UINT32_T Conversions *********************/ -template <> -DS_D_INLINE uint32_t to(double val) -{ - return __double2uint_rn(val); -} -template <> -DS_D_INLINE uint32_t to(float val) -{ - return __float2uint_rn(val); -} -template <> -DS_D_INLINE uint32_t to(__half val) -{ - return __half2uint_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE uint32_t to(__nv_bfloat16 val) -{ - return __bfloat162uint_rn(val); -} -#endif - -/********************* To UINT16_T Conversions *********************/ -template <> -DS_D_INLINE uint16_t to(double val) -{ - return __double2uint_rn(val); -} -template <> -DS_D_INLINE uint16_t to(float val) -{ - return __float2uint_rn(val); -} -template <> -DS_D_INLINE uint16_t to(__half val) -{ - return __half2uint_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE uint16_t to(__nv_bfloat16 val) -{ - return __bfloat162uint_rn(val); -} -#endif - -/********************* To UINT8_T Conversions *********************/ -template <> -DS_D_INLINE uint8_t to(double val) -{ - return __double2uint_rn(val); -} -template <> -DS_D_INLINE uint8_t to(float val) -{ - return __float2uint_rn(val); -} -template <> -DS_D_INLINE uint8_t to(__half val) -{ - return __half2uint_rn(val); -} -// No direct support for integer casts at the C++ level and I don't feel they're so important -// to demand an PTX at this time - -#ifdef BF16_AVAILABLE -template <> -DS_D_INLINE uint8_t to(__nv_bfloat16 val) -{ - return __bfloat162uint_rn(val); -} -#endif - -} // namespace conversion \ No newline at end of file diff --git a/colossalai/gptq/csrc/includes/ds_kernel_utils.h b/colossalai/gptq/csrc/includes/ds_kernel_utils.h deleted file mode 100644 index 99d8be17e503..000000000000 --- a/colossalai/gptq/csrc/includes/ds_kernel_utils.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -/* -Centralized header file for preprocessor macros and constants -used throughout the codebase. -*/ - -#pragma once - -#include - -#define DS_HD_INLINE __host__ __device__ __forceinline__ -#define DS_D_INLINE __device__ __forceinline__ - -#ifdef __HIP_PLATFORM_HCC__ - -// constexpr variant of warpSize for templating -constexpr int hw_warp_size = 64; -#define HALF_PRECISION_AVAILABLE = 1 -#include - -#else // !__HIP_PLATFORM_HCC__ - -// constexpr variant of warpSize for templating -constexpr int hw_warp_size = 32; - -#if __CUDA_ARCH__ >= 530 -#define HALF_PRECISION_AVAILABLE = 1 -#define PTX_AVAILABLE -#endif // __CUDA_ARCH__ >= 530 - -#if __CUDA_ARCH__ >= 800 -#define ASYNC_COPY_AVAILABLE -#define BF16_AVAILABLE -#endif // __CUDA_ARCH__ >= 800 - -#include - -#endif //__HIP_PLATFORM_HCC__ - -inline int next_pow2(const int val) -{ - int rounded_val = val - 1; - rounded_val |= rounded_val >> 1; - rounded_val |= rounded_val >> 2; - rounded_val |= rounded_val >> 4; - rounded_val |= rounded_val >> 8; - return rounded_val + 1; -} \ No newline at end of file diff --git a/colossalai/gptq/csrc/includes/inference_cuda_layers.h b/colossalai/gptq/csrc/includes/inference_cuda_layers.h deleted file mode 100644 index 08ea00c6558c..000000000000 --- a/colossalai/gptq/csrc/includes/inference_cuda_layers.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#pragma once - -#include -#ifdef BF16_AVAILABLE -#include -#endif -#include -#include -#include -#include -#include -#include - -template -at::Tensor gptq_act_linear_layer(at::Tensor& input, - at::Tensor& weight, - at::Tensor& weight_scales, - at::Tensor& weight_zeros, - at::Tensor& bias, - at::Tensor& residual, - int64_t group_size, - int32_t act_type, - int32_t add_bias, - int32_t add_residual, - int32_t qkv_fused, - uint64_t block_size_x, - uint64_t block_size_y); \ No newline at end of file diff --git a/colossalai/gptq/csrc/pt_binding.cpp b/colossalai/gptq/csrc/pt_binding.cpp deleted file mode 100644 index 80e96fde6673..000000000000 --- a/colossalai/gptq/csrc/pt_binding.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include "inference_cuda_layers.h" -#include -#include -#include -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - - m.def("gptq_act_linear_fp16", - &gptq_act_linear_layer<__half, uint64_t>, - "gptq linear kernel (CUDA)"); - - - m.def("gptq_act_linear_fp16_w32", - &gptq_act_linear_layer<__half, uint32_t>, - "gptq linear kernel (CUDA)"); - - m.def("gptq_act_linear_fp16_w8", - &gptq_act_linear_layer<__half, uint8_t>, - "gptq linear kernel (CUDA)"); - -} diff --git a/colossalai/gptq/inference_builder.py b/colossalai/gptq/inference_builder.py deleted file mode 100644 index 60cb7167a76e..000000000000 --- a/colossalai/gptq/inference_builder.py +++ /dev/null @@ -1,761 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -import os -import sys -import time -import importlib -from pathlib import Path -import subprocess -import shlex -import shutil -import tempfile -import distutils.ccompiler -import distutils.log -import distutils.sysconfig -from distutils.errors import CompileError, LinkError -from abc import ABC, abstractmethod -from typing import List - -YELLOW = '\033[93m' -END = '\033[0m' -WARNING = f"{YELLOW} [WARNING] {END}" - -DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions" -DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0" - -try: - import torch -except ImportError: - print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops.") -else: - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - - -def installed_cuda_version(name=""): - import torch.utils.cpp_extension - cuda_home = torch.utils.cpp_extension.CUDA_HOME - assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" - # Ensure there is not a cuda version mismatch between torch and nvcc compiler - output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) - output_split = output.split() - release_idx = output_split.index("release") - release = output_split[release_idx + 1].replace(',', '').split(".") - # Ignore patch versions, only look at major + minor - cuda_major, cuda_minor = release[:2] - return int(cuda_major), int(cuda_minor) - - -def get_default_compute_capabilities(): - compute_caps = DEFAULT_COMPUTE_CAPABILITIES - import torch.utils.cpp_extension - if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11: - if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0: - # Special treatment of CUDA 11.0 because compute_86 is not supported. - compute_caps += ";8.0" - else: - compute_caps += ";8.0;8.6" - return compute_caps - - -# list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used -# to build deepspeed and system-wide installed cuda 11.2 -cuda_minor_mismatch_ok = { - 10: [ - "10.0", - "10.1", - "10.2", - ], - 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], -} - - -def assert_no_cuda_mismatch(name=""): - cuda_major, cuda_minor = installed_cuda_version(name) - sys_cuda_version = f'{cuda_major}.{cuda_minor}' - torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) - # This is a show-stopping error, should probably not proceed past this - if sys_cuda_version != torch_cuda_version: - if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major] - and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]): - print(f"Installed CUDA version {sys_cuda_version} does not match the " - f"version torch was compiled with {torch.version.cuda} " - "but since the APIs are compatible, accepting this combination") - return True - raise Exception(f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " - f"version torch was compiled with {torch.version.cuda}, unable to compile " - "cuda/cpp extensions without a matching cuda version.") - return True - - -class OpBuilder(ABC): - _rocm_version = None - _is_rocm_pytorch = None - - def __init__(self, name): - self.name = name - self.jit_mode = False - self.build_for_cpu = False - self.error_log = None - - @abstractmethod - def absolute_name(self): - ''' - Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam - will be installed as something like: deepspeed/ops/adam/cpu_adam.so - ''' - pass - - @abstractmethod - def sources(self): - ''' - Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) - ''' - pass - - def hipify_extension(self): - pass - - @staticmethod - def validate_torch_version(torch_info): - install_torch_version = torch_info['version'] - current_torch_version = ".".join(torch.__version__.split('.')[:2]) - if install_torch_version != current_torch_version: - raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed " - "with a different version than what is being used at runtime. " - f"Please re-install DeepSpeed or switch torch versions. " - f"Install torch version={install_torch_version}, " - f"Runtime torch version={current_torch_version}") - - @staticmethod - def validate_torch_op_version(torch_info): - if not OpBuilder.is_rocm_pytorch(): - current_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) - install_cuda_version = torch_info['cuda_version'] - if install_cuda_version != current_cuda_version: - raise RuntimeError("CUDA version mismatch! DeepSpeed ops were compiled and installed " - "with a different version than what is being used at runtime. " - f"Please re-install DeepSpeed or switch torch versions. " - f"Install CUDA version={install_cuda_version}, " - f"Runtime CUDA version={current_cuda_version}") - else: - current_hip_version = ".".join(torch.version.hip.split('.')[:2]) - install_hip_version = torch_info['hip_version'] - if install_hip_version != current_hip_version: - raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed " - "with a different version than what is being used at runtime. " - f"Please re-install DeepSpeed or switch torch versions. " - f"Install HIP version={install_hip_version}, " - f"Runtime HIP version={current_hip_version}") - - @staticmethod - def is_rocm_pytorch(): - if OpBuilder._is_rocm_pytorch is not None: - return OpBuilder._is_rocm_pytorch - - _is_rocm_pytorch = False - try: - import torch - except ImportError: - pass - else: - if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): - _is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None - if _is_rocm_pytorch: - from torch.utils.cpp_extension import ROCM_HOME - _is_rocm_pytorch = ROCM_HOME is not None - OpBuilder._is_rocm_pytorch = _is_rocm_pytorch - return OpBuilder._is_rocm_pytorch - - @staticmethod - def installed_rocm_version(): - if OpBuilder._rocm_version: - return OpBuilder._rocm_version - - ROCM_MAJOR = '0' - ROCM_MINOR = '0' - if OpBuilder.is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev") - if rocm_ver_file.is_file(): - with open(rocm_ver_file, 'r') as file: - ROCM_VERSION_DEV_RAW = file.read() - elif "rocm" in torch.__version__: - ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1] - else: - assert False, "Could not detect ROCm version" - assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version" - ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] - ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] - OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) - return OpBuilder._rocm_version - - def include_paths(self): - ''' - Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) - ''' - return [] - - def nvcc_args(self): - ''' - Returns optional list of compiler flags to forward to nvcc when building CUDA sources - ''' - return [] - - def cxx_args(self): - ''' - Returns optional list of compiler flags to forward to the build - ''' - return [] - - def is_compatible(self, verbose=True): - ''' - Check if all non-python dependencies are satisfied to build this op - ''' - return True - - def extra_ldflags(self): - return [] - - def libraries_installed(self, libraries): - valid = False - check_cmd = 'dpkg -l' - for lib in libraries: - result = subprocess.Popen(f'dpkg -l {lib}', stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) - valid = valid or result.wait() == 0 - return valid - - def has_function(self, funcname, libraries, verbose=False): - ''' - Test for existence of a function within a tuple of libraries. - - This is used as a smoke test to check whether a certain library is available. - As a test, this creates a simple C program that calls the specified function, - and then distutils is used to compile that program and link it with the specified libraries. - Returns True if both the compile and link are successful, False otherwise. - ''' - tempdir = None # we create a temporary directory to hold various files - filestderr = None # handle to open file to which we redirect stderr - oldstderr = None # file descriptor for stderr - try: - # Echo compile and link commands that are used. - if verbose: - distutils.log.set_verbosity(1) - - # Create a compiler object. - compiler = distutils.ccompiler.new_compiler(verbose=verbose) - - # Configure compiler and linker to build according to Python install. - distutils.sysconfig.customize_compiler(compiler) - - # Create a temporary directory to hold test files. - tempdir = tempfile.mkdtemp() - - # Define a simple C program that calls the function in question - prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname) - - # Write the test program to a file. - filename = os.path.join(tempdir, 'test.c') - with open(filename, 'w') as f: - f.write(prog) - - # Redirect stderr file descriptor to a file to silence compile/link warnings. - if not verbose: - filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w') - oldstderr = os.dup(sys.stderr.fileno()) - os.dup2(filestderr.fileno(), sys.stderr.fileno()) - - # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames() - # Otherwise, a local directory will be used instead of tempdir - drive, driveless_filename = os.path.splitdrive(filename) - root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else '' - output_dir = os.path.join(drive, root_dir) - - # Attempt to compile the C program into an object file. - cflags = shlex.split(os.environ.get('CFLAGS', "")) - objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags)) - - # Attempt to link the object file into an executable. - # Be sure to tack on any libraries that have been specified. - ldflags = shlex.split(os.environ.get('LDFLAGS', "")) - compiler.link_executable(objs, - os.path.join(tempdir, 'a.out'), - extra_preargs=self.strip_empty_entries(ldflags), - libraries=libraries) - - # Compile and link succeeded - return True - - except CompileError: - return False - - except LinkError: - return False - - except: - return False - - finally: - # Restore stderr file descriptor and close the stderr redirect file. - if oldstderr is not None: - os.dup2(oldstderr, sys.stderr.fileno()) - if filestderr is not None: - filestderr.close() - - # Delete the temporary directory holding the test program and stderr files. - if tempdir is not None: - shutil.rmtree(tempdir) - - def strip_empty_entries(self, args): - ''' - Drop any empty strings from the list of compile and link flags - ''' - return [x for x in args if len(x) > 0] - - def cpu_arch(self): - try: - from cpuinfo import get_cpu_info - except ImportError as e: - cpu_info = self._backup_cpuinfo() - if cpu_info is None: - return "-march=native" - - try: - cpu_info = get_cpu_info() - except Exception as e: - self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " - "falling back to `lscpu` to get this information.") - cpu_info = self._backup_cpuinfo() - if cpu_info is None: - return "-march=native" - - if cpu_info['arch'].startswith('PPC_'): - # gcc does not provide -march on PowerPC, use -mcpu instead - return '-mcpu=native' - return '-march=native' - - def is_cuda_enable(self): - try: - assert_no_cuda_mismatch(self.name) - return '-D__ENABLE_CUDA__' - except BaseException: - print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " - "only cpu ops can be compiled!") - return '-D__DISABLE_CUDA__' - return '-D__DISABLE_CUDA__' - - def _backup_cpuinfo(self): - # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides - if not self.command_exists('lscpu'): - self.warning(f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo " - "to detect the CPU architecture. 'lscpu' does not appear to exist on " - "your system, will fall back to use -march=native and non-vectorized execution.") - return None - result = subprocess.check_output('lscpu', shell=True) - result = result.decode('utf-8').strip().lower() - - cpu_info = {} - cpu_info['arch'] = None - cpu_info['flags'] = "" - if 'genuineintel' in result or 'authenticamd' in result: - cpu_info['arch'] = 'X86_64' - if 'avx512' in result: - cpu_info['flags'] += 'avx512,' - elif 'avx512f' in result: - cpu_info['flags'] += 'avx512f,' - if 'avx2' in result: - cpu_info['flags'] += 'avx2' - elif 'ppc64le' in result: - cpu_info['arch'] = "PPC_" - - return cpu_info - - def simd_width(self): - try: - from cpuinfo import get_cpu_info - except ImportError as e: - cpu_info = self._backup_cpuinfo() - if cpu_info is None: - return '-D__SCALAR__' - - try: - cpu_info = get_cpu_info() - except Exception as e: - self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " - "falling back to `lscpu` to get this information.") - cpu_info = self._backup_cpuinfo() - if cpu_info is None: - return '-D__SCALAR__' - - if cpu_info['arch'] == 'X86_64': - if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']: - return '-D__AVX512__' - elif 'avx2' in cpu_info['flags']: - return '-D__AVX256__' - return '-D__SCALAR__' - - def command_exists(self, cmd): - if '|' in cmd: - cmds = cmd.split("|") - else: - cmds = [cmd] - valid = False - for cmd in cmds: - result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) - valid = valid or result.wait() == 0 - - if not valid and len(cmds) > 1: - print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!") - elif not valid and len(cmds) == 1: - print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!") - return valid - - def warning(self, msg): - self.error_log = f"{msg}" - print(f"{WARNING} {msg}") - - def deepspeed_src_path(self, code_path): - if os.path.isabs(code_path): - return code_path - else: - return os.path.join(Path(__file__).parent.parent.absolute(), code_path) - - def builder(self): - from torch.utils.cpp_extension import CppExtension - return CppExtension(name=self.absolute_name(), - sources=self.strip_empty_entries(self.sources()), - include_dirs=self.strip_empty_entries(self.include_paths()), - extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())}, - extra_link_args=self.strip_empty_entries(self.extra_ldflags())) - - def load(self, verbose=True): - return self.jit_load(verbose) - - def jit_load(self, verbose=True): - if not self.is_compatible(verbose): - raise RuntimeError( - f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}" - ) - try: - import ninja # noqa: F401 - except ImportError: - raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") - - if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): - try: - assert_no_cuda_mismatch(self.name) - self.build_for_cpu = False - except BaseException: - self.build_for_cpu = True - - self.jit_mode = True - from torch.utils.cpp_extension import load - - start_build = time.time() - sources = [self.deepspeed_src_path(path) for path in self.sources()] - extra_include_paths = [self.deepspeed_src_path(path) for path in self.include_paths()] - - # Torch will try and apply whatever CCs are in the arch list at compile time, - # we have already set the intended targets ourselves we know that will be - # needed at runtime. This prevents CC collisions such as multiple __half - # implementations. Stash arch list to reset after build. - torch_arch_list = None - if "TORCH_CUDA_ARCH_LIST" in os.environ: - torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") - os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - op_module = load(name=self.name, - sources=self.strip_empty_entries(sources), - extra_include_paths=self.strip_empty_entries(extra_include_paths), - extra_cflags=self.strip_empty_entries(self.cxx_args()), - extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()), - extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), - verbose=verbose) - - build_duration = time.time() - start_build - if verbose: - print(f"Time to load {self.name} op: {build_duration} seconds") - - # Reset arch list so we are not silently removing it for other possible use cases - if torch_arch_list: - os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list - - return op_module - - -class CUDAOpBuilder(OpBuilder): - - def compute_capability_args(self, cross_compile_archs=None): - """ - Returns nvcc compute capability compile flags. - - 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. - 2. If neither is set default compute capabilities will be used - 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX - - Format: - - - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: - - TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ... - TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ... - - - `cross_compile_archs` uses ; separator. - - """ - ccs = [] - if self.jit_mode: - # Compile for underlying architectures since we know those at runtime - for i in range(torch.cuda.device_count()): - CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) - cc = f"{CC_MAJOR}.{CC_MINOR}" - if cc not in ccs: - ccs.append(cc) - ccs = sorted(ccs) - ccs[-1] += '+PTX' - else: - # Cross-compile mode, compile for various architectures - # env override takes priority - cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None) - if cross_compile_archs_env is not None: - if cross_compile_archs is not None: - print( - f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`" - ) - cross_compile_archs = cross_compile_archs_env.replace(' ', ';') - else: - if cross_compile_archs is None: - cross_compile_archs = get_default_compute_capabilities() - ccs = cross_compile_archs.split(';') - - ccs = self.filter_ccs(ccs) - if len(ccs) == 0: - raise RuntimeError( - f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") - - args = [] - for cc in ccs: - num = cc[0] + cc[2] - args.append(f'-gencode=arch=compute_{num},code=sm_{num}') - if cc.endswith('+PTX'): - args.append(f'-gencode=arch=compute_{num},code=compute_{num}') - - return args - - def filter_ccs(self, ccs: List[str]): - """ - Prune any compute capabilities that are not compatible with the builder. Should log - which CCs have been pruned. - """ - return ccs - - def version_dependent_macros(self): - # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 - version_ge_1_1 = [] - if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): - version_ge_1_1 = ['-DVERSION_GE_1_1'] - version_ge_1_3 = [] - if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - version_ge_1_3 = ['-DVERSION_GE_1_3'] - version_ge_1_5 = [] - if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): - version_ge_1_5 = ['-DVERSION_GE_1_5'] - return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 - - def is_compatible(self, verbose=True): - return super().is_compatible(verbose) - - def builder(self): - try: - assert_no_cuda_mismatch(self.name) - self.build_for_cpu = False - except BaseException: - self.build_for_cpu = True - - if self.build_for_cpu: - from torch.utils.cpp_extension import CppExtension as ExtensionBuilder - else: - from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder - - compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \ - {'cxx': self.strip_empty_entries(self.cxx_args()), \ - 'nvcc': self.strip_empty_entries(self.nvcc_args())} - - cuda_ext = ExtensionBuilder(name=self.absolute_name(), - sources=self.strip_empty_entries(self.sources()), - include_dirs=self.strip_empty_entries(self.include_paths()), - libraries=self.strip_empty_entries(self.libraries_args()), - extra_compile_args=compile_args) - - if self.is_rocm_pytorch(): - # hip converts paths to absolute, this converts back to relative - sources = cuda_ext.sources - curr_file = Path(__file__).parent.parent # ds root - for i in range(len(sources)): - src = Path(sources[i]) - if src.is_absolute(): - sources[i] = str(src.relative_to(curr_file)) - else: - sources[i] = str(src) - cuda_ext.sources = sources - return cuda_ext - - def hipify_extension(self): - if self.is_rocm_pytorch(): - from torch.utils.hipify import hipify_python - hipify_python.hipify( - project_directory=os.getcwd(), - output_directory=os.getcwd(), - header_include_dirs=self.include_paths(), - includes=[os.path.join(os.getcwd(), '*')], - extra_files=[os.path.abspath(s) for s in self.sources()], - show_detailed=True, - is_pytorch_extension=True, - hipify_extra_files_only=True, - ) - - def cxx_args(self): - if sys.platform == "win32": - return ['-O2'] - else: - return ['-O3', '-std=c++14', '-g', '-Wno-reorder'] - - def nvcc_args(self): - if self.build_for_cpu: - return [] - args = ['-O3'] - if self.is_rocm_pytorch(): - ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() - args += [ - '-std=c++14', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__', - '-U__HIP_NO_HALF2_OPERATORS__', - '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, - '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR - ] - else: - cuda_major, _ = installed_cuda_version() - args += [ - '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', - '-std=c++17' if sys.platform == "win32" and cuda_major > 10 else '-std=c++14', - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__' - ] - if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1': - args.append('--ptxas-options=-v') - args += self.compute_capability_args() - return args - - def libraries_args(self): - if self.build_for_cpu: - return [] - - if sys.platform == "win32": - return ['cublas', 'curand'] - else: - return [] - - -class TorchCPUOpBuilder(CUDAOpBuilder): - - def extra_ldflags(self): - if self.build_for_cpu: - return ['-fopenmp'] - - if not self.is_rocm_pytorch(): - return ['-lcurand'] - - return [] - - def cxx_args(self): - import torch - args = [] - if not self.build_for_cpu: - if not self.is_rocm_pytorch(): - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64") - else: - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib") - - args += super().cxx_args() - args += [ - f'-L{CUDA_LIB64}', - '-lcudart', - '-lcublas', - '-g', - ] - - CPU_ARCH = self.cpu_arch() - SIMD_WIDTH = self.simd_width() - CUDA_ENABLE = self.is_cuda_enable() - args += [ - CPU_ARCH, - '-fopenmp', - SIMD_WIDTH, - CUDA_ENABLE, - ] - - return args - -class InferenceBuilder(CUDAOpBuilder): - BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE" - NAME = "transformer_inference" - - def __init__(self, name=None): - name = self.NAME if name is None else name - super().__init__(name=name) - - def absolute_name(self): - return f'cai.inference.{self.NAME}_op' - - def is_compatible(self, verbose=True): - try: - import torch - except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") - return False - - cuda_okay = True - if not self.is_rocm_pytorch() and torch.cuda.is_available(): - sys_cuda_major, _ = installed_cuda_version() - torch_cuda_major = int(torch.version.cuda.split('.')[0]) - cuda_capability = torch.cuda.get_device_properties(0).major - if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") - cuda_okay = False - if cuda_capability >= 8: - if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") - cuda_okay = False - return super().is_compatible(verbose) and cuda_okay - - def filter_ccs(self, ccs): - ccs_retained = [] - ccs_pruned = [] - for cc in ccs: - if int(cc[0]) >= 6: - ccs_retained.append(cc) - else: - ccs_pruned.append(cc) - if len(ccs_pruned) > 0: - self.warning(f"Filtered compute capabilities {ccs_pruned}") - return ccs_retained - - def sources(self): - return [ - 'gptq/csrc/pt_binding.cpp', - 'gptq/csrc/gptq_act_linear.cu', - ] - - def extra_ldflags(self): - if not self.is_rocm_pytorch(): - return ['-lcurand', '-L/home/lcxk/data3/anaconda3/envs/triton/lib'] - else: - return [] - - def include_paths(self): - return ['gptq/csrc/includes'] - - -builder = InferenceBuilder() -inference_cuda_module = builder.load() \ No newline at end of file diff --git a/tests/test_gptq/linear_act_fusion_bench.py b/tests/test_gptq/linear_act_fusion_bench.py index 13d716412f47..d518243b0059 100644 --- a/tests/test_gptq/linear_act_fusion_bench.py +++ b/tests/test_gptq/linear_act_fusion_bench.py @@ -13,7 +13,6 @@ from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp import math import numpy as np -from colossalai.gptq import CaiInferenceConfig import csv class MLinear(nn.Module): @@ -322,9 +321,8 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize # qzeros = torch.cat((qzeros, qzeros, qzeros), dim=0).contiguous() # bias = torch.cat((bias, bias, bias), dim=0).contiguous() qkv_fused = False - cai_inf_config = CaiInferenceConfig(fp16=True) - cai_linear = CaiGPTQLinearOp(cai_inf_config) + cai_linear = CaiGPTQLinearOp(args.groupsize, args.wbits) print("cai linear") for i in range(0, warm_up_iter): diff --git a/tests/test_gptq/quant_llama.py b/tests/test_gptq/quant_llama.py index 52ac0d81f6b1..cc6d980019db 100644 --- a/tests/test_gptq/quant_llama.py +++ b/tests/test_gptq/quant_llama.py @@ -503,59 +503,60 @@ def test(model_1, model_2): safe_save(state_dict, args.save_safetensors) if args.benchmark: - # model = model.to(DEV) - # print(f"model config {model.config.num_key_value_heads}") + model = model.to(DEV) + print(f"model config {model.config.num_key_value_heads}") - # if args.model_type == "cai": - # cai_inf_config = CaiInferenceConfig(fp16=True, - # device=torch.cuda.current_device(), - # gptq=True, - # gptq_group_size=128, - # gptq_quant_bits=4) - # model = convert_to_ds_model(model, cai_inf_config) - # model.cuda().to(torch.cuda.current_device()) + if args.model_type == "cai": + cai_inf_config = CaiInferenceConfig(fp16=True, + device=torch.cuda.current_device(), + gptq=True, + gptq_group_size=128, + gptq_quant_bits=4) + model = convert_to_ds_model(model, cai_inf_config) + model.cuda().to(torch.cuda.current_device()) + benchmark(model) - torch_model = get_llama(args.model) - torch_model.half() - torch_model = torch_model.to(DEV) + # torch_model = get_llama(args.model) + # torch_model.half() + # torch_model = torch_model.to(DEV) - gptq_model = gptq_load_quant(args.model, "llama7b-4bit-128g-gptq-nao.pt", args.wbits, args.groupsize) - gptq_model = gptq_model.to(DEV) + # gptq_model = gptq_load_quant(args.model, "llama7b-4bit-128g-gptq-nao.pt", args.wbits, args.groupsize) + # gptq_model = gptq_model.to(DEV) - model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) - model = model.to(DEV) + # model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) + # model = model.to(DEV) - test(torch_model, model) - test(gptq_model, None) - - print("torch_model ", torch_model) - print("gptq_model ", gptq_model) - print("cai_model ", model) - torch_qkv_out = torch_model.model.layers[0].self_attn.qkv_out - cai_qkv_out = model.model.layers[0].self_attn.qkv_out - gptq_qkv_out = gptq_model.model.layers[0].self_attn.qkv_out - - gptq_out = gptq_model.model.layers[0].self_attn.q_proj.scales - cai_out = model.model.layers[0].self_attn.q_proj.scales - - mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) - max_diff = torch.max(torch.abs(cai_out - gptq_out)) - print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - for i in range(3): - cai_out = cai_qkv_out[i] - torch_out = torch_qkv_out[i] - gptq_out = gptq_qkv_out[i] - mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) - max_diff = torch.max(torch.abs(cai_out - gptq_out)) - print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) - max_diff = torch.max(torch.abs(torch_out - gptq_out)) - print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - mean_diff = torch.mean(torch.abs(torch_out - cai_out)) - max_diff = torch.max(torch.abs(torch_out - cai_out)) - print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # test(torch_model, model) + # test(gptq_model, None) + + # print("torch_model ", torch_model) + # print("gptq_model ", gptq_model) + # print("cai_model ", model) + # torch_qkv_out = torch_model.model.layers[0].self_attn.qkv_out + # cai_qkv_out = model.model.layers[0].self_attn.qkv_out + # gptq_qkv_out = gptq_model.model.layers[0].self_attn.qkv_out + + # gptq_out = gptq_model.model.layers[0].self_attn.q_proj.scales + # cai_out = model.model.layers[0].self_attn.q_proj.scales + + # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) + # max_diff = torch.max(torch.abs(cai_out - gptq_out)) + # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # for i in range(3): + # cai_out = cai_qkv_out[i] + # torch_out = torch_qkv_out[i] + # gptq_out = gptq_qkv_out[i] + # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) + # max_diff = torch.max(torch.abs(cai_out - gptq_out)) + # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) + # max_diff = torch.max(torch.abs(torch_out - gptq_out)) + # print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(torch_out - cai_out)) + # max_diff = torch.max(torch.abs(torch_out - cai_out)) + # print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) # # for batch in [1, 2, 4, 8, 16, 32]: # for batch in [1]: diff --git a/tests/test_gptq/run_gptq.sh b/tests/test_gptq/run_gptq.sh index 03aaca8f60df..cc625478de71 100644 --- a/tests/test_gptq/run_gptq.sh +++ b/tests/test_gptq/run_gptq.sh @@ -6,9 +6,9 @@ # --wbits 4 --true-sequential --groupsize 128 --save ./llama7b-4bit-128g-gptq-nao.pt\ # --benchmark --model_type gptq --input_len 1024 --max_new_tokens 128 --batch_size 1 -OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ - --wbits 4 --true-sequential --act-order --groupsize 128 --load ./llama7b-4bit-128g-cai-nao.pt\ - --benchmark --model_type cai --input_len 1024 --max_new_tokens 128 --batch_size 1 +# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ +# --wbits 4 --true-sequential --act-order --groupsize 128 --load ./llama7b-4bit-128g-cai-nao.pt\ +# --benchmark --model_type cai --input_len 1024 --max_new_tokens 128 --batch_size 1 # OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ # --wbits 4 --true-sequential --act-order --groupsize 128 --load /llama7b-4bit-128g-gptq-nao.pt \ diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py index 8079a9966843..b81388a41b43 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -13,7 +13,6 @@ from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp import math import numpy as np -from colossalai.gptq import CaiInferenceConfig import csv class MLinear(nn.Module): @@ -247,8 +246,8 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize gptq_model.to(torch.cuda.current_device()) # gptq_model = linear - cai_inf_config = CaiInferenceConfig(fp16=True) - cai_linear = CaiGPTQLinearOp(cai_inf_config) + + cai_linear = CaiGPTQLinearOp(args.groupsize, args.wbits) # qweight = torch.cat((qweight, qweight, qweight), dim=0).contiguous() @@ -312,91 +311,4 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize print("batch torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) mean_diff = torch.mean(torch.abs(batch_torch_out - batch_cai_out)) max_diff = torch.max(torch.abs(batch_torch_out - batch_cai_out)) - print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - - # print("torch time: {:.8f}, gptq time:{:.8f}, cai time: {:.8f} ".format(torch_linear_time/benchmark_iter, gptq_linear_time/benchmark_iter, cai_linear_time/benchmark_iter)) - # print("torch time: {:.8f}, gptq time:{:.8f}, cai time: {:.8f} ".format(torch_linear_time/benchmark_iter, gptq_linear_time/benchmark_iter, cai_linear_time/benchmark_iter)) - - - - - - # linear = MLinear(infeature, outfeature) - # linear.to(torch.cuda.current_device()) - # with torch.no_grad(): - # linear.linear.weight.data.copy_(weight) - # linear.linear.bias.data.copy_(bias) - # inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) - # quantizers = model_quant(linear, inps, torch.cuda.current_device(), args) - # qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) - # cai_inf_config = CaiInferenceConfig(fp16=True, device=torch.cuda.current_device()) - - # cai_linear = GPTQActLinearOp(cai_inf_config) - - # batch_inps = torch.randn(1, 4, infeature).to(torch.float16).to(torch.cuda.current_device()) - - # relu = nn.ReLU() - # # act_inps = relu(inps) - # # act_batch_inps = relu(batch_inps) - # # batch_inps = torch.ones(1, 2, infeature).to(torch.float16).to(torch.cuda.current_device()) - # # inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) - # # gptq_out = relu(inps) - # linear.to("cuda") - # with torch.no_grad(): - # torch_out = linear(inps) - # torch_batch_out = linear(batch_inps) - # # torch_out = relu(torch_out) - # # torch_batch_out = relu(torch_batch_out) - - # linear.to("cpu") - - # gptq_model = model_pack(linear, quantizers, args.wbits, args.groupsize) - # gptq_model.to(torch.cuda.current_device()) - - - # with torch.no_grad(): - # gptq_out = gptq_model(inps) - # cai_out = cai_linear(inps, - # qweight, - # qscales, - # qzeros, - # act_type = 1, - # bias = bias) - # gptq_batch_out = gptq_model(batch_inps) - # cai_batch_out = cai_linear(batch_inps, - # qweight, - # qscales, - # qzeros, - # act_type = 1, - # bias = bias) - - # torch.cuda.synchronize() - # # gptq_out = relu(gptq_out) - # # re_gptq_batch_out = relu(gptq_batch_out) - # # print(f"cai out {cai_out}") - # # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) - # # max_diff = torch.max(torch.abs(cai_out - gptq_out)) - # # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # # mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) - # # max_diff = torch.max(torch.abs(torch_out - gptq_out)) - # # print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # # mean_diff = torch.mean(torch.abs(torch_out - cai_out)) - # # max_diff = torch.max(torch.abs(torch_out - cai_out)) - # # print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - - # print(f"cai batch out {cai_batch_out}") - # print(f"gptq batch out {gptq_batch_out}") - # print(f"torch batch out {torch_batch_out}") - # # print(f"gptq batch out {re_gptq_batch_out}") - - # mean_diff = torch.mean(torch.abs(cai_batch_out - gptq_batch_out)) - # max_diff = torch.max(torch.abs(cai_batch_out - gptq_batch_out)) - # print("cai vs gptq batch 128: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(torch_batch_out - gptq_batch_out)) - # max_diff = torch.max(torch.abs(torch_batch_out - gptq_batch_out)) - # print("torch vs gptq batch 128: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(torch_batch_out - cai_batch_out)) - # max_diff = torch.max(torch.abs(torch_batch_out - cai_batch_out)) - # print("torch vs cai batch 128: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - - + print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) \ No newline at end of file From a56d61b9f9e58f063b8598aac43567a5484cc5f0 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 11 Aug 2023 17:00:37 +0800 Subject: [PATCH 03/15] fix tests --- requirements/requirements.txt | 1 + tests/test_gptq/test_quant_llama.py | 530 ---------------------------- 2 files changed, 1 insertion(+), 530 deletions(-) delete mode 100644 tests/test_gptq/test_quant_llama.py diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 65eecce2c34f..eece233e4e48 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -12,3 +12,4 @@ torch>=1.11 safetensors flash_attn>=2.0 einops +texttable diff --git a/tests/test_gptq/test_quant_llama.py b/tests/test_gptq/test_quant_llama.py deleted file mode 100644 index 9f73a116f5bf..000000000000 --- a/tests/test_gptq/test_quant_llama.py +++ /dev/null @@ -1,530 +0,0 @@ -import argparse -import time -import numpy as np -import torch -import torch.nn as nn -from colossalai.gptq.gptq_utils import quant -from colossalai.gptq import cai_gptq - -from colossalai.gptq.gptq_utils import GPTQ, Observer -from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions -from texttable import Texttable -from colossalai.gptq import CaiInferenceConfig -from transformers import LlamaForCausalLM, LlamaTokenizer - -import csv - -def get_llama(model): - - def skip(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - from transformers import LlamaForCausalLM, LlamaConfig, LlamaModel - if args.debug: - llama_kwargs= {"bos_token_id": 0, - "eos_token_id": 1, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 2048, - "max_sequence_length": 2048, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 1, - "pad_token_id": -1, - "rms_norm_eps": 1e-06, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "use_cache": True, - "vocab_size": 32000 - } - configuration = LlamaConfig( **llama_kwargs - ) - model = LlamaForCausalLM(configuration) - else: - model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16) - - # # LlamaForCausalLM - model.seqlen = 2048 - return model - - -@torch.no_grad() -def llama_sequential(model, dataloader, dev): - print('Starting ...') - - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.layers - - model.model.embed_tokens = model.model.embed_tokens.to(dev) - model.model.norm = model.model.norm.to(dev) - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) - cache = {'i': 0, 'attention_mask': None} - - class Catcher(nn.Module): - - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - cache['position_ids'] = kwargs['position_ids'] - raise ValueError - - layers[0] = Catcher(layers[0]) - for batch in dataloader: - try: - model(batch[0].to(dev)) - except ValueError: - pass - layers[0] = layers[0].module - - layers[0] = layers[0].cpu() - model.model.embed_tokens = model.model.embed_tokens.cpu() - model.model.norm = model.model.norm.cpu() - torch.cuda.empty_cache() - - outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] - position_ids = cache['position_ids'] - - print('Ready.') - - quantizers = {} - observer = Observer() - for i in range(len(layers)): - - print(f'Quantizing layer {i+1}/{len(layers)}..') - print('+------------------+--------------+------------+-----------+-------+') - print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') - print('+==================+==============+============+===========+=======+') - - layer = layers[i].to(dev) - full = find_layers(layer) - if args.true_sequential: - sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']] - else: - sequential = [list(full.keys())] - - for names in sequential: - subset = {n: full[n] for n in names} - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name], observe=args.observe) - gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) - - def add_batch(name): - - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] - for h in handles: - h.remove() - - for name in subset: - scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name) - quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize) - - if args.observe: - observer.submit(name=name, layerid=i, gptq=gptq[name], error=error) - else: - gptq[name].free() - - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] - - layers[i] = layer.cpu() - del layer - del gptq - torch.cuda.empty_cache() - - inps, outs = outs, inps - print('+------------------+--------------+------------+-----------+-------+') - print('\n') - - if args.observe: - observer.print() - conditions = gen_conditions(args.wbits, args.groupsize) - for item in observer.items(): - name = item[0] - layerid = item[1] - gptq = item[2]['gptq'] - error = item[2]['error'] - target = error / 2 - - table = Texttable() - table.header(['wbits', 'groupsize', 'error']) - table.set_cols_dtype(['i', 'i', 'f']) - table.add_row([args.wbits, args.groupsize, error]) - - print('Optimizing {} {} ..'.format(name, layerid)) - for wbits, groupsize in conditions: - - if error < target: - # if error dropped 50%, skip - break - - gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False) - - scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name) - - table.add_row([wbits, groupsize, error]) - quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) - - print(table.draw()) - print('\n') - gptq.layer.to('cpu') - gptq.free() - - model.config.use_cache = use_cache - - return quantizers - - -# TODO: perform packing on GPU -def cai_llama_pack(model, quantizers, wbits, groupsize): - layers = find_layers(model) - # print(f"model {model}") - # print(f"layers {layers}") - - layers = {n: layers[n] for n in quantizers} - # print(f"quantizers {quantizers}") - cai_gptq.make_cai_quant_linear(model, quantizers, wbits, groupsize) - qlayers = find_layers(model, [cai_gptq.CaiQuantLinear]) - print('Packing ...') - for name in qlayers: - print(name) - quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') - return model - -def gptq_llama_pack(model, quantizers, wbits, groupsize): - layers = find_layers(model) - # print(f"model {model}") - # print(f"layers {layers}") - - layers = {n: layers[n] for n in quantizers} - # print(f"quantizers {quantizers}") - quant.make_quant_linear(model, quantizers, wbits, groupsize) - qlayers = find_layers(model, [quant.QuantLinear]) - print('Packing ...') - for name in qlayers: - print(name) - quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') - return model - - -def cai_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): - from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils - config = LlamaConfig.from_pretrained(model) - - def noop(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = LlamaForCausalLM(config) - torch.set_default_dtype(torch.float) - if eval: - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - cai_gptq.make_cai_quant_linear(model, layers, wbits, groupsize) - - del layers - - print('Loading model ...') - if checkpoint.endswith('.safetensors'): - from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint)) - else: - model.load_state_dict(torch.load(checkpoint)) - - print('Done.') - - return model - - -def gptq_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): - from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils - config = LlamaConfig.from_pretrained(model) - - def noop(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = LlamaForCausalLM(config) - torch.set_default_dtype(torch.float) - if eval: - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - quant.make_quant_linear(model, layers, wbits, groupsize) - - del layers - - print('Loading model ...') - if checkpoint.endswith('.safetensors'): - from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint)) - else: - model.load_state_dict(torch.load(checkpoint)) - - print('Done.') - - return model - -all_perfs = [] -now_perf=[] - -def print_perf_stats(latency_set, config, warmup=3): - global now_perf - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 - # if args.dtype == "float16": - # num_bytes = 2 - # elif args.dtype == "float32": - # num_bytes = 4 - # else: - # num_bytes = 1 - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1/avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1/avg * num_parameters * num_bytes * args.batch_size / 1e12)) - print("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) - print("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) - row = [args.batch_size, args.input_len, args.max_new_tokens, "{0:8.2f}".format(avg * 1000), - "{0:8.2f}".format(torch.cuda.memory_allocated() / 1e9), - "{0:8.2f}".format(torch.cuda.max_memory_allocated()/1e9)] - with open('./{}_profile.csv'.format(args.model_type), 'a', encoding='UTF8') as f: - # create the csv writer - writer = csv.writer(f) - - # write a row to the csv file - writer.writerow(row) - - now_perf.append("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - now_perf.append("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) - now_perf.append("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) - - all_perfs.append(now_perf) - now_perf = [] - -def benchmark(model): - - input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), - "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} - torch.cuda.synchronize() - iters = 10 if args.benchmark else 2 #warmup - print(f"model config {model.config}") - - times = [] - warmup=3 - prof_flag = 0 - generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) - torch.cuda.reset_peak_memory_stats() - for i in range(iters): - if i >= warmup: - prof_flag=1 - torch.cuda.synchronize() - start = time.time() - outputs = model.generate(**input_tokens, - **generate_kwargs) - torch.cuda.synchronize() - end = time.time() - times.append(end - start) - print("outpus shape: ", outputs.shape) - print(args) - print("input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens) - # if args.local_rank == 0: - now_perf.extend(["input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens]) - print_perf_stats(map(lambda t: t / args.max_new_tokens, times), model.config) - -def test(model_1, model_2): - # input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), - # "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} - generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) - - - tokenizer = LlamaTokenizer.from_pretrained(args.model) - tokenizer.pad_token_id = tokenizer.unk_token_id - - text = "how is weather today? I want to know the weather of beijing. " - text = "how are you?" - - inputs = [text] - input_tokens = tokenizer.batch_encode_plus(inputs, padding = True, return_tensors="pt") - - input_len = 0 - for t in input_tokens: - if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) - # print(input_tokens[t].shape) - input_len = input_tokens[t].shape[1] - - outputs_1 = model_1.generate(**input_tokens, - **generate_kwargs) - - - outputs_2 = model_2.generate(**input_tokens, - **generate_kwargs) - - out_1 = tokenizer.batch_decode(outputs_1) - out_2 = tokenizer.batch_decode(outputs_2) - - ret = torch.allclose(outputs_1, outputs_2) - print("allclose is ", ret) - print("decode out:", out_1) - print("decode out:", out_2) - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - - parser.add_argument('model', type=str, help='llama model to load') - parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') - parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') - parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.') - parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') - parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.') - parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') - parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') - parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') - parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') - parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.') - parser.add_argument('--load', type=str, default='', help='Load quantized model.') - parser.add_argument('--benchmark', action='store_true', help='Number of tokens to use for benchmarking.') - parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') - parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') - parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') - parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential model.') - parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.') - parser.add_argument('--observe', - action='store_true', - help='Auto upgrade layer precision to higher precision, for example int2 to int4, groupsize 128 to 64. \ - When this feature enabled, `--save` or `--save_safetensors` would be disable.') - parser.add_argument('--quant-directory', type=str, default=None, help='Specify the directory for export quantization parameters to toml format. `None` means no export by default.') - parser.add_argument('--max_new_tokens', type=int, default=32, help='Max new tokens to generate.') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size to generate.') - parser.add_argument('--input_len', type=int, default=128, help='Batch size to generate.') - parser.add_argument('--model_type', type=str, choices=['cai', 'gptq', 'torch'], default='torch', help='Batch size to generate.') - parser.add_argument('--debug', action='store_true', help='Whether to debug or not') - - args = parser.parse_args() - - model_packed = False - if type(args.load) is not str: - args.load = args.load.as_posix() - - if args.load: - if args.model_type == "gptq": - model = gptq_load_quant(args.model, args.load, args.wbits, args.groupsize) - elif args.model_type == "cai": - model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) - else: - model = get_llama(args.model) - model.half() - - if not args.load and args.wbits < 16 and not args.nearest and args.model_type in ['cai', 'gptq']: - dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen) - tick = time.time() - quantizers = llama_sequential(model, dataloader, DEV) - if args.model_type == "cai": - cai_llama_pack(model, quantizers, args.wbits, args.groupsize) - elif args.model_type == "gptq": - gptq_llama_pack(model, quantizers, args.wbits, args.groupsize) - model_packed = True - print(time.time() - tick) - - - if args.quant_directory is not None: - export_quant_table(quantizers, args.quant_directory) - - if not args.observe and args.save and args.model_type in ['cai', 'gptq']: - if not model_packed: - llama_pack(model, quantizers, args.wbits, args.groupsize) - model_packed = True - torch.save(model.state_dict(), args.save) - - if not args.observe and args.save_safetensors and args.model_type in ['cai', 'gptq']: - if not model_packed: - llama_pack(model, quantizers, args.wbits, args.groupsize) - from safetensors.torch import save_file as safe_save - state_dict = model.state_dict() - state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} - safe_save(state_dict, args.save_safetensors) - - if args.benchmark: - model = model.to(DEV) - # print(f"model config {model.config.num_key_value_heads}") - - # if args.model_type == "cai": - # cai_inf_config = CaiInferenceConfig(fp16=True, - # device=torch.cuda.current_device(), - # gptq=True, - # gptq_group_size=128, - # gptq_quant_bits=4) - # model = convert_to_ds_model(model, cai_inf_config) - # model.cuda().to(torch.cuda.current_device()) - - - torch_model = get_llama(args.model) - torch_model.half() - torch_model = torch_model.to(DEV) - - test(torch_model, model) - - # # for batch in [1, 2, 4, 8, 16, 32]: - # for batch in [1]: - # args.batch_size = batch - # # for in_len in [128, 256, 512, 1024, 2048]: - # for in_len in [1024]: - # args.input_len = in_len - # benchmark(model) - # # for info in all_perfs: - # # print(info) - # # # all_perfs = [] \ No newline at end of file From ef97b74730e84656c1c67bb967b1c530ef8599ab Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 16 Aug 2023 17:24:03 +0800 Subject: [PATCH 04/15] replace auto-gptq --- colossalai/gptq/cai_gptq/cai_quant_linear.py | 21 +- colossalai/gptq/cai_gptq/gptq_op.py | 5 +- colossalai/gptq/cai_gptq/gptq_triton.py | 21 +- colossalai/gptq/gptq_utils/__init__.py | 1 - colossalai/gptq/gptq_utils/gptq.py | 236 -------- colossalai/gptq/gptq_utils/quant/__init__.py | 5 - .../gptq/gptq_utils/quant/custom_autotune.py | 194 ------ .../gptq/gptq_utils/quant/fused_attn.py | 204 ------- colossalai/gptq/gptq_utils/quant/fused_mlp.py | 288 --------- .../gptq/gptq_utils/quant/quant_linear.py | 422 ------------- colossalai/gptq/gptq_utils/quant/quantizer.py | 127 ---- .../gptq/gptq_utils/quant/triton_norm.py | 92 --- colossalai/gptq/gptq_utils/utils/__init__.py | 3 - colossalai/gptq/gptq_utils/utils/datautils.py | 193 ------ colossalai/gptq/gptq_utils/utils/export.py | 37 -- .../gptq/gptq_utils/utils/modelutils.py | 83 --- .../ops/gptq}/linear_act_fusion_bench.py | 65 +- requirements/requirements.txt | 1 + tests/test_gptq/quant_llama.py | 570 ------------------ tests/test_gptq/run_gptq.sh | 19 - tests/test_gptq/test_linear_act_fusion.py | 75 +-- 21 files changed, 98 insertions(+), 2564 deletions(-) delete mode 100644 colossalai/gptq/gptq_utils/__init__.py delete mode 100644 colossalai/gptq/gptq_utils/gptq.py delete mode 100644 colossalai/gptq/gptq_utils/quant/__init__.py delete mode 100644 colossalai/gptq/gptq_utils/quant/custom_autotune.py delete mode 100644 colossalai/gptq/gptq_utils/quant/fused_attn.py delete mode 100644 colossalai/gptq/gptq_utils/quant/fused_mlp.py delete mode 100644 colossalai/gptq/gptq_utils/quant/quant_linear.py delete mode 100644 colossalai/gptq/gptq_utils/quant/quantizer.py delete mode 100644 colossalai/gptq/gptq_utils/quant/triton_norm.py delete mode 100644 colossalai/gptq/gptq_utils/utils/__init__.py delete mode 100644 colossalai/gptq/gptq_utils/utils/datautils.py delete mode 100644 colossalai/gptq/gptq_utils/utils/export.py delete mode 100644 colossalai/gptq/gptq_utils/utils/modelutils.py rename {tests/test_gptq => examples/ops/gptq}/linear_act_fusion_bench.py (88%) delete mode 100644 tests/test_gptq/quant_llama.py delete mode 100644 tests/test_gptq/run_gptq.sh diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index 72a8e6d5607c..737b24462dc4 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -3,7 +3,6 @@ import numpy as np import torch import torch.nn as nn -from torch.cuda.amp import custom_bwd, custom_fwd from .gptq_op import CaiGPTQLinearOp import triton @@ -22,10 +21,7 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64)) self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64)) self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) - # self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int64)) - # self.order_qzeros = torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int64) - # self.register_buffer('input_idx', torch.zeros(infeatures], dtype=torch.int32)) - + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) if bias: self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) @@ -33,11 +29,10 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias): self.bias = None self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) - self.printed = False - self.reorder_zeros = False - def pack(self, linear, scales, zeros, g_idx=None): + def pack(self, linear, scales, zeros, g_idx=None): + g_idx = g_idx.clone() if g_idx is not None else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) scales = scales.t().contiguous() @@ -103,8 +98,13 @@ def pack(self, linear, scales, zeros, g_idx=None): raise NotImplementedError("Only 2,4,8 bits are supported.") qzeros = qzeros.astype(sign_type) qzeros = torch.from_numpy(qzeros) - qzeros = qzeros #.to(torch.cuda.current_device()) + qzeros = qzeros self.qzeros.data.copy_(qzeros) + + if torch.equal(self.g_idx, g_idx): + self.g_idx = None + else: + self.g_idx = g_idx def forward(self, x): @@ -113,7 +113,8 @@ def forward(self, x): self.qweight, self.scales, self.qzeros, - bias = self.bias) + g_idx = self.g_idx, + bias = self.bias,) return cai_out def make_cai_quant_linear(module, names, bits, groupsize, name=''): diff --git a/colossalai/gptq/cai_gptq/gptq_op.py b/colossalai/gptq/cai_gptq/gptq_op.py index 7ada87055a97..aca1cb5b87c5 100644 --- a/colossalai/gptq/cai_gptq/gptq_op.py +++ b/colossalai/gptq/cai_gptq/gptq_op.py @@ -16,6 +16,7 @@ def forward(self, weight: torch.Tensor, weight_scales: torch.Tensor, weight_zeros: torch.Tensor, + g_idx: torch.Tensor = None, act_type = 0, bias: torch.Tensor = None, residual: torch.Tensor=None, @@ -32,9 +33,9 @@ def forward(self, add_residual = False x = input.view(-1, input.shape[-1]) - out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual, - self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual, act_type=act_type) + self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual, + act_type=act_type, g_idx=g_idx) if qkv_fused: out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) else: diff --git a/colossalai/gptq/cai_gptq/gptq_triton.py b/colossalai/gptq/cai_gptq/gptq_triton.py index eb77987ef625..8a505ebad73f 100644 --- a/colossalai/gptq/cai_gptq/gptq_triton.py +++ b/colossalai/gptq/cai_gptq/gptq_triton.py @@ -1,7 +1,7 @@ import triton import triton.language as tl import torch -from ..gptq_utils.quant import custom_autotune +from auto_gptq.nn_modules.triton_utils import custom_autotune # from ..ops.triton.kernels.activations_kernels import relu, gelu, silu # code based https://github.com/fpgaminer/GPTQ-triton # triton.Config({ @@ -221,11 +221,11 @@ def cai_gptq_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, bias_ for k in range(0, num_pid_k): # g_idx = tl.load(g_ptrs) - if (k + 1) * BLOCK_SIZE_K > currend_group_end: - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) + # if (k + 1) * BLOCK_SIZE_K > currend_group_end: + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated @@ -391,8 +391,7 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i for k in range(0, num_pid_k): # g_idx = tl.load(g_ptrs) - if (k + 1) * BLOCK_SIZE_K > currend_group_end: - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq @@ -438,7 +437,7 @@ def cai_gptq_idx_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, i def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, - bits, maxq, gptq_group_size, qkv_fused, add_bias, add_residual, idx = None, act_type = 0): + bits, maxq, gptq_group_size, qkv_fused, add_bias, add_residual, g_idx = None, act_type = 0): # print("gptq fused ", qkv_fused, add_bias, add_residual) with torch.cuda.device(input.device): if qkv_fused: @@ -448,7 +447,7 @@ def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) # print("dtype, ", qweight.dtype, output.dtype, scales.dtype, qzeros.dtype, bias.dtype, residual.dtype) - if idx is None: + if g_idx is None: cai_gptq_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, bias, residual, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, gptq_group_size, @@ -456,7 +455,7 @@ def gptq_fused_linear_triton(input, qweight, scales, qzeros, bias, residual, qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0), QKV_FUSED=qkv_fused, ADD_BIAS=add_bias, ADD_RESIDUAL=add_residual, ACT_TYPE=act_type) else: - cai_gptq_idx_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, idx, bias, residual, + cai_gptq_idx_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, bias, residual, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, gptq_group_size, input.stride(0), input.stride(1), qweight.stride(0), diff --git a/colossalai/gptq/gptq_utils/__init__.py b/colossalai/gptq/gptq_utils/__init__.py deleted file mode 100644 index 2b9db6637df3..000000000000 --- a/colossalai/gptq/gptq_utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .gptq import GPTQ, Observer \ No newline at end of file diff --git a/colossalai/gptq/gptq_utils/gptq.py b/colossalai/gptq/gptq_utils/gptq.py deleted file mode 100644 index e17c0c47c6d7..000000000000 --- a/colossalai/gptq/gptq_utils/gptq.py +++ /dev/null @@ -1,236 +0,0 @@ -import math -import time - -import torch -import torch.nn as nn -import transformers -from .quant import Quantizer -from texttable import Texttable -from .utils import torch_snr_error - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - - -class Observer: - - def __init__(self, topk=32): - self.loss_list = [] - self.topk = topk - - def submit(self, name: str, layerid: int, gptq, error: float): - - item = (name, layerid, {'gptq': gptq, 'error': error}) - - if len(self.loss_list) < self.topk: - self.loss_list.append(item) - return - - min_error = error - min_idx = -1 - for idx, data in enumerate(self.loss_list): - if min_error > data[2]['error']: - min_idx = idx - min_error = data[2]['error'] - - if min_idx >= 0: - self.loss_list[min_idx] = item - - def print(self): - self.loss_list = sorted(self.loss_list, key=lambda s: s[2]['error'], reverse=True) - - table = Texttable() - - table.header(['name', 'error']) - table.set_cols_dtype(['t', 'f']) - - for item in self.loss_list: - table.add_row([f"{item[0]}.{item[1]}", item[2]['error']]) - print(table.draw()) - print('\n') - - def items(self): - return self.loss_list - - -class GPTQ: - - def __init__(self, layer, observe=False): - self.layer = layer - self.dev = self.layer.weight.device - W = layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): - W = W.t() - self.rows = W.shape[0] - self.columns = W.shape[1] - self.H = torch.zeros((self.columns, self.columns), device=self.dev) - self.nsamples = 0 - self.quantizer = Quantizer() - self.observe = observe - - def add_batch(self, inp, out): - # Hessian H = 2 X XT + λ I - if self.observe: - self.inp1 = inp - self.out1 = out - else: - self.inp1 = None - self.out1 = None - - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - if isinstance(self.layer, nn.Conv2d): - unfold = nn.Unfold(self.layer.kernel_size, dilation=self.layer.dilation, padding=self.layer.padding, stride=self.layer.stride) - inp = unfold(inp) - inp = inp.permute([1, 0, 2]) - inp = inp.flatten(1) - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - # inp = inp.float() - inp = math.sqrt(2 / self.nsamples) * inp.float() - # self.H += 2 / self.nsamples * inp.matmul(inp.t()) - self.H += inp.matmul(inp.t()) - - def print_loss(self, name, q_weight, weight_error, timecost): - table = Texttable() - name += ' ' * (16 - len(name)) - - table.header(['name', 'weight_error', 'fp_inp_SNR', 'q_inp_SNR', 'time']) - - # assign weight - self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) - - if self.inp1 is not None: - # quantize input to int8 - quantizer = Quantizer() - quantizer.configure(8, perchannel=False, sym=True, mse=False) - quantizer.find_params(self.inp1) - q_in = quantizer.quantize(self.inp1).type(torch.float16) - q_out = self.layer(q_in) - - # get kinds of SNR - q_SNR = torch_snr_error(q_out, self.out1).item() - fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() - else: - q_SNR = '-' - fp_SNR = '-' - - table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) - print(table.draw().split('\n')[-2]) - - def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, name=''): - self.layer.to(self.dev) - - W = self.layer.weight.data.clone() - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - if isinstance(self.layer, transformers.Conv1D): - W = W.t() - W = W.float() - - tick = time.time() - - if not self.quantizer.ready(): - self.quantizer.find_params(W, weight=True) - - H = self.H - if not self.observe: - del self.H - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - W[:, dead] = 0 - - if actorder: - perm = torch.argsort(torch.diag(H), descending=True) - W = W[:, perm] - H = H[perm][:, perm] - - Losses = torch.zeros_like(W) - Q = torch.zeros_like(W) - - damp = percdamp * torch.mean(torch.diag(H)) - diag = torch.arange(self.columns, device=self.dev) - H[diag, diag] += damp - H = torch.linalg.cholesky(H) - H = torch.cholesky_inverse(H) - H = torch.linalg.cholesky(H, upper=True) - Hinv = H - - g_idx = [] - scale = [] - zero = [] - now_idx = 1 - - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) - count = i2 - i1 - - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - - if groupsize != -1: - if (i1 + i) % groupsize == 0: - self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True) - - if ((i1 + i) // groupsize) - now_idx == -1: - scale.append(self.quantizer.scale) - zero.append(self.quantizer.zero) - now_idx += 1 - - q = self.quantizer.quantize(w.unsqueeze(1)).flatten() - Q1[:, i] = q - Losses1[:, i] = (w - q)**2 / d**2 - - err1 = (w - q) / d - W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - Err1[:, i] = err1 - - Q[:, i1:i2] = Q1 - Losses[:, i1:i2] = Losses1 / 2 - - W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - - torch.cuda.synchronize() - error = torch.sum(Losses).item() - - groupsize = groupsize if groupsize != -1 else self.columns - g_idx = [i // groupsize for i in range(self.columns)] - g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) - if actorder: - invperm = torch.argsort(perm) - Q = Q[:, invperm] - g_idx = g_idx[invperm] - - if isinstance(self.layer, transformers.Conv1D): - Q = Q.t() - - self.print_loss(name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)) - - if scale == []: - scale.append(self.quantizer.scale) - zero.append(self.quantizer.zero) - scale = torch.cat(scale, dim=1) - zero = torch.cat(zero, dim=1) - return scale, zero, g_idx, error - - def free(self): - self.inp1 = None - self.out1 = None - self.H = None - self.Losses = None - self.Trace = None - torch.cuda.empty_cache() diff --git a/colossalai/gptq/gptq_utils/quant/__init__.py b/colossalai/gptq/gptq_utils/quant/__init__.py deleted file mode 100644 index 64452784656b..000000000000 --- a/colossalai/gptq/gptq_utils/quant/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .quantizer import Quantizer -from .fused_attn import QuantLlamaAttention, make_quant_attn -from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused -from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear -from .triton_norm import TritonLlamaRMSNorm, make_quant_norm diff --git a/colossalai/gptq/gptq_utils/quant/custom_autotune.py b/colossalai/gptq/gptq_utils/quant/custom_autotune.py deleted file mode 100644 index 286cf5d08586..000000000000 --- a/colossalai/gptq/gptq_utils/quant/custom_autotune.py +++ /dev/null @@ -1,194 +0,0 @@ -#https://github.com/fpgaminer/GPTQ-triton -""" -Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. -""" - -import builtins -import math -import time -from typing import Dict - -import triton - - -class Autotuner(triton.KernelInterface): - - def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): - ''' - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. - 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results - ''' - if not configs: - self.configs = [triton.Config({}, num_warps=4, num_stages=2)] - else: - self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] - self.nearest_power_of_two = nearest_power_of_two - self.cache = {} - # hook to reset all required tensor to zeros before relaunching a kernel - self.hook = lambda args: 0 - if reset_to_zero is not None: - self.reset_idx = [arg_names.index(k) for k in reset_to_zero] - - def _hook(args): - for i in self.reset_idx: - args[i].zero_() - - self.hook = _hook - self.arg_names = arg_names - # prune configs - if prune_configs_by: - perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] - if 'early_config_prune' in prune_configs_by: - early_config_prune = prune_configs_by['early_config_prune'] - else: - perf_model, top_k, early_config_prune = None, None, None - self.perf_model, self.configs_top_k = perf_model, top_k - self.early_config_prune = early_config_prune - self.fn = fn - - def _bench(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols.") - # augment meta-parameters with tunable ones - current = dict(meta, **config.kwargs) - - def kernel_call(): - if config.pre_hook: - config.pre_hook(self.nargs) - self.hook(args) - self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) - return triton.testing.do_bench(kernel_call, rep=40) - - # try: - # # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses - # # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default - # return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) - # except triton.compiler.OutOfResources: - # return (float('inf'), float('inf'), float('inf')) - - def run(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - if len(self.configs) > 1: - key = tuple(args[i] for i in self.key_idx) - - # This reduces the amount of autotuning by rounding the keys to the nearest power of two - # In my testing this gives decent results, and greatly reduces the amount of tuning required - if self.nearest_power_of_two: - key = tuple([2**int(math.log2(x) + 0.5) for x in key]) - - if key not in self.cache: - # prune configs - pruned_configs = self.prune_configs(kwargs) - bench_start = time.time() - timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - self.hook(args) - self.configs_timings = timings - config = self.cache[key] - else: - config = self.configs[0] - self.best_config = config - if config.pre_hook is not None: - config.pre_hook(self.nargs) - return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) - - def prune_configs(self, kwargs): - pruned_configs = self.configs - if self.early_config_prune: - pruned_configs = self.early_config_prune(self.configs, self.nargs) - if self.perf_model: - top_k = self.configs_top_k - if isinstance(top_k, float) and top_k <= 1.0: - top_k = int(len(self.configs) * top_k) - if len(pruned_configs) > top_k: - est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} - pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] - return pruned_configs - - def warmup(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - for config in self.prune_configs(kwargs): - self.fn.warmup( - *args, - num_warps=config.num_warps, - num_stages=config.num_stages, - **kwargs, - **config.kwargs, - ) - self.nargs = None - - -def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): - """ - Decorator for auto-tuning a :code:`triton.jit`'d function. - .. highlight:: python - .. code-block:: python - @triton.autotune(configs=[ - triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), - triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), - ], - key=['x_size'] # the two above configs will be evaluated anytime - # the value of x_size changes - ) - @triton.jit - def kernel(x_ptr, x_size, **META): - BLOCK_SIZE = META['BLOCK_SIZE'] - :note: When all the configurations are evaluated, the kernel will run multiple time. - This means that whatever value the kernel updates will be updated multiple times. - To avoid this undesired behavior, you can use the `reset_to_zero` argument, which - reset the value of the provided tensor to `zero` before running any configuration. - :param configs: a list of :code:`triton.Config` objects - :type configs: list[triton.Config] - :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. - :type key: list[str] - :param prune_configs_by: a dict of functions that are used to prune configs, fields: - 'perf_model': performance model used to predicate running time with different configs, returns running time - 'top_k': number of configs to bench - 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. - :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. - :type reset_to_zero: list[str] - """ - - def decorator(fn): - return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) - - return decorator - - -def matmul248_kernel_config_pruner(configs, nargs): - """ - The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. - """ - m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) - n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) - k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) - - used = set() - for config in configs: - block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) - block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) - block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) - group_size_m = config.kwargs['GROUP_SIZE_M'] - - if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: - continue - - used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) - yield triton.Config({ - 'BLOCK_SIZE_M': block_size_m, - 'BLOCK_SIZE_N': block_size_n, - 'BLOCK_SIZE_K': block_size_k, - 'GROUP_SIZE_M': group_size_m - }, - num_stages=config.num_stages, - num_warps=config.num_warps) diff --git a/colossalai/gptq/gptq_utils/quant/fused_attn.py b/colossalai/gptq/gptq_utils/quant/fused_attn.py deleted file mode 100644 index 2076c2cc37b2..000000000000 --- a/colossalai/gptq/gptq_utils/quant/fused_attn.py +++ /dev/null @@ -1,204 +0,0 @@ -from torch.nn import functional as F -from transformers.models.llama.modeling_llama import LlamaAttention -from .quant_linear import * -import triton -import triton.language as tl - - -@triton.jit -def rotate_half_kernel( - qk_seq_ptr, - position_ids_ptr, - qk_seq_stride, - position_ids_batch_stride, - seq_len, - HEAD_DIM: tl.constexpr, - BLOCK_HEIGHT: tl.constexpr, - BLOCK_WIDTH: tl.constexpr, - INV_BASE: tl.constexpr -): - # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension. - # position ids: (bsz, seq_len) -- must be contiguous in the last dimension. - - HALF_HEAD: tl.constexpr = HEAD_DIM // 2 - STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH - - batch_seq = tl.program_id(axis=0) - row_blk_x_col_blk = tl.program_id(axis=1) - - row_blk = row_blk_x_col_blk // STEPS_PER_ROW - row = row_blk * BLOCK_HEIGHT - if BLOCK_WIDTH < HALF_HEAD: - col_blk = row_blk_x_col_blk % STEPS_PER_ROW - col = col_blk * BLOCK_WIDTH - else: - col: tl.constexpr = 0 - - # A block will never cross a sequence boundary, which simplifies things a lot. - batch = batch_seq // seq_len - seq = batch_seq % seq_len - position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq) - # As sometimes happens, just calculating this on the fly is faster than loading it from memory. - # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate. - freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id - cos = tl.cos(freq).to(tl.float32) - sin = tl.sin(freq).to(tl.float32) - - col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH) - embed_offsets = (row * HEAD_DIM + col) + col_offsets - x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets - - for k in range(0, BLOCK_HEIGHT): - x = tl.load(x_ptrs).to(tl.float32) - y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32) - out_x = x * cos - y * sin - tl.store(x_ptrs, out_x) - out_y = x * sin + y * cos - tl.store(x_ptrs + HALF_HEAD, out_y) - x_ptrs += HEAD_DIM - - -def triton_rotate_half_(qk, position_ids, config=None): - with torch.cuda.device(qk.device): - batch_size, seq_len, qandk, num_heads, head_dim = qk.shape - - # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective. - config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1} - config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads) - - assert qk.stride(3) == head_dim - assert qk.stride(4) == 1 - assert position_ids.shape == (batch_size, seq_len) - assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension' - assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config["BLOCK_HEIGHT"]}' - assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config["BLOCK_WIDTH"]}' - - qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim) - grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH'])) - - # Must be the same as the theta of the frequencies used to train the model. - BASE = 10000.0 - - rotate_half_kernel[grid]( - qk_by_seq, - position_ids, - qk_by_seq.stride(0), - position_ids.stride(0), - seq_len, - HEAD_DIM=head_dim, - BLOCK_HEIGHT=config['BLOCK_HEIGHT'], - BLOCK_WIDTH=config['BLOCK_WIDTH'], - INV_BASE=-2.0 * math.log(BASE) / head_dim, - num_warps=config['num_warps'] - ) - - -class QuantLlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - hidden_size, - num_heads, - qkv_proj, - o_proj - ): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads}).") - self.qkv_proj = qkv_proj - self.o_proj = o_proj - - def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False): - """Input shape: Batch x Time x Channel""" - - bsz, q_len, _ = hidden_states.size() - - qkv_states = self.qkv_proj(hidden_states) - qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim) - - # This updates the query and key states in-place, saving VRAM. - triton_rotate_half_(qkv_states[:, :, :2], position_ids) - - query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2) - del qkv_states - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - is_causal = past_key_value is None - - kv_seq_len = q_len - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - if use_cache: - # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor - # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. - key_states = key_states.contiguous() - value_states = value_states.contiguous() - query_states = query_states.contiguous() - - past_key_value = (key_states, value_states) if use_cache else None - - with torch.backends.cuda.sdp_kernel(enable_math=False): - attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal) - del query_states, key_states, value_states - - attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -def make_quant_attn(model): - """ - Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. - """ - - for name, m in model.named_modules(): - if not isinstance(m, LlamaAttention): - continue - - q_proj = m.q_proj - k_proj = m.k_proj - v_proj = m.v_proj - - qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) - g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None - - qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False) - qkv_layer.qweight = qweights - qkv_layer.qzeros = qzeros - qkv_layer.scales = scales - qkv_layer.g_idx = g_idx - qkv_layer.bias = bias - # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch. - - attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj) - - if '.' in name: - parent_name = name.rsplit('.', 1)[0] - child_name = name[len(parent_name) + 1:] - parent = model.get_submodule(parent_name) - else: - parent_name = '' - parent = model - child_name = name - - #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") - - setattr(parent, child_name, attn) diff --git a/colossalai/gptq/gptq_utils/quant/fused_mlp.py b/colossalai/gptq/gptq_utils/quant/fused_mlp.py deleted file mode 100644 index a5e402e38f94..000000000000 --- a/colossalai/gptq/gptq_utils/quant/fused_mlp.py +++ /dev/null @@ -1,288 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_bwd, custom_fwd -from transformers.models.llama.modeling_llama import LlamaMLP - -try: - import triton - import triton.language as tl - from . import custom_autotune - - # code based https://github.com/fpgaminer/GPTQ-triton - @custom_autotune.autotune( - configs=[ - triton.Config({ - 'BLOCK_SIZE_M': 256, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), # 3090 - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 16, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), # 3090 - triton.Config({ - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), # 3090 - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 16, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), # 3090 - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), # 3090 - ], - key=['M', 'N', 'K'], - nearest_power_of_two=True, - prune_configs_by={ - 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, - 'perf_model': None, - 'top_k': None, - }, - ) - @triton.jit - def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, - stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): - """ - Computes: C = silu(A * B1) * (A * B2) - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (1, N) float16 - zeros is of shape (1, N//8) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) - b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) - g1_ptrs = g1_ptr + offs_k - g2_ptrs = g2_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales1_ptrs = scales1_ptr + offs_bn[None, :] - scales2_ptrs = scales2_ptr + offs_bn[None, :] - zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits) - zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, num_pid_k): - g1_idx = tl.load(g1_ptrs) - g2_idx = tl.load(g2_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales) - - zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq - zeros1 = (zeros1 + 1) - - zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq - zeros2 = (zeros2 + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - b2 = tl.load(b2_ptrs) - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values - b1 = (b1 - zeros1) * scales1 # Scale and shift - accumulator1 += tl.dot(a, b1) - - b2 = (b2 >> shifter[:, None]) & maxq - b2 = (b2 - zeros2) * scales2 - accumulator2 += tl.dot(a, b2) - - a_ptrs += BLOCK_SIZE_K - b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g1_ptrs += BLOCK_SIZE_K - g2_ptrs += BLOCK_SIZE_K - - accumulator1 = silu(accumulator1) - c = accumulator1 * accumulator2 - c = c.to(tl.float16) - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - @triton.jit - def silu(x): - return x * tl.sigmoid(x) -except: - print('triton not installed.') - - -class QuantLlamaMLP(nn.Module): - - def __init__( - self, - gate_proj, - down_proj, - up_proj, - ): - super().__init__() - self.register_buffer('gate_proj_qweight', gate_proj.qweight) - self.register_buffer('gate_proj_scales', gate_proj.scales) - self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) - self.register_buffer('gate_proj_g_idx', gate_proj.g_idx) - self.register_buffer('up_proj_qweight', up_proj.qweight) - self.register_buffer('up_proj_scales', up_proj.scales) - self.register_buffer('up_proj_qzeros', up_proj.qzeros) - self.register_buffer('up_proj_g_idx', up_proj.g_idx) - - self.infeatures = gate_proj.infeatures - self.intermediate_size = gate_proj.outfeatures - self.outfeatures = down_proj.outfeatures - self.bits = gate_proj.bits - self.maxq = gate_proj.maxq - - self.down_proj = down_proj - - def forward(self, x): - return self.down_proj(self.triton_llama_mlp(x)) - - def triton_llama_mlp(self, x): - with torch.cuda.device(x.device): - out_shape = x.shape[:-1] + (self.intermediate_size, ) - x = x.reshape(-1, x.shape[-1]) - M, K = x.shape - N = self.intermediate_size - c = torch.empty((M, N), device=x.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales, - self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0), - self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0)) - c = c.reshape(out_shape) - return c - - def fused2cuda(self): - self.gate_proj_qweight = self.gate_proj_qweight.cuda() - self.gate_proj_scales = self.gate_proj_scales.cuda() - self.gate_proj_qzeros = self.gate_proj_qzeros.cuda() - self.gate_proj_g_idx = self.gate_proj_g_idx.cuda() - self.up_proj_qweight = self.up_proj_qweight.cuda() - self.up_proj_scales = self.up_proj_scales.cuda() - self.up_proj_qzeros = self.up_proj_qzeros.cuda() - self.up_proj_g_idx = self.up_proj_g_idx.cuda() - - def fused2cpu(self): - self.gate_proj_qweight = self.gate_proj_qweight.cpu() - self.gate_proj_scales = self.gate_proj_scales.cpu() - self.gate_proj_qzeros = self.gate_proj_qzeros.cpu() - self.gate_proj_g_idx = self.gate_proj_g_idx.cpu() - self.up_proj_qweight = self.up_proj_qweight.cpu() - self.up_proj_scales = self.up_proj_scales.cpu() - self.up_proj_qzeros = self.up_proj_qzeros.cpu() - self.up_proj_g_idx = self.up_proj_g_idx.cpu() - - -def make_fused_mlp(m, parent_name=''): - """ - Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. - """ - if isinstance(m, LlamaMLP): - return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj) - - for name, child in m.named_children(): - child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") - - if isinstance(child, QuantLlamaMLP): - setattr(m, name, child) - return m - - -def autotune_warmup_fused(model): - """ - Pre-tunes the quantized kernel - """ - from tqdm import tqdm - - kn_values = {} - - for _, m in model.named_modules(): - if not isinstance(m, QuantLlamaMLP): - continue - - k = m.infeatures - n = m.intermediate_size - - m.fused2cuda() - if (k, n) not in kn_values: - kn_values[(k, n)] = m - - print(f'Found {len(kn_values)} unique fused mlp KN values.') - - print('Warming up autotune cache ...') - with torch.no_grad(): - for m in tqdm(range(0, 12)): - m = 2**m # [1, 2048] - for (k, n), (modules) in kn_values.items(): - a = torch.randn(m, k, dtype=torch.float16, device='cuda') - modules.triton_llama_mlp(a) - - for (k, n), (modules) in kn_values.items(): - a = torch.randn(m, k, dtype=torch.float16, device='cuda') - modules.fused2cpu() - del kn_values diff --git a/colossalai/gptq/gptq_utils/quant/quant_linear.py b/colossalai/gptq/gptq_utils/quant/quant_linear.py deleted file mode 100644 index 5144a962a928..000000000000 --- a/colossalai/gptq/gptq_utils/quant/quant_linear.py +++ /dev/null @@ -1,422 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from torch.cuda.amp import custom_bwd, custom_fwd - -try: - import triton - import triton.language as tl - from . import custom_autotune - - # code based https://github.com/fpgaminer/GPTQ-triton - @custom_autotune.autotune( - configs=[ - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), - ], - key=['M', 'N', 'K'], - nearest_power_of_two=True, - prune_configs_by={ - 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, - 'perf_model': None, - 'top_k': None, - }, - ) - @triton.jit - def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_k - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_bn[None, :] - zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) - - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, num_pid_k): - g_idx = tl.load(g_ptrs) - - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - g_ptrs += BLOCK_SIZE_K - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - @custom_autotune.autotune(configs=[ - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 256, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=4, num_warps=4), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 128, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 8 - }, num_stages=3, num_warps=8), - triton.Config({ - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - }, num_stages=2, num_warps=4), - ], - key=['M', 'N', 'K'], - nearest_power_of_two=True) - @triton.jit - def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, - stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): - """ - Compute the matrix multiplication C = A x B. - A is of shape (M, N) float16 - B is of shape (K//8, N) int32 - C is of shape (M, K) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N) float16 - g_ptr is of shape (K) int32 - """ - infearure_per_bits = 32 // bits - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_k - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_k = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - offs_n = tl.arange(0, BLOCK_SIZE_N) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - g_ptrs = g_ptr + offs_bk - g_idx = tl.load(g_ptrs) - - # shifter is used to extract the N bits of each element in the 32-bit word from B - scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales - zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros - - shifter = (offs_bk % infearure_per_bits) * bits - zeros_shifter = (offs_n % infearure_per_bits) * bits - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - - for n in range(0, num_pid_n): - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) - - zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = (zeros + 1) - - a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - - # Now we need to unpack b (which is N-bit values) into 32-bit values - b = (b >> shifter[:, None]) & maxq # Extract the N-bit values - b = (b - zeros) * scales # Scale and shift - b = tl.trans(b) - - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_N - b_ptrs += BLOCK_SIZE_N - scales_ptrs += BLOCK_SIZE_N - zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) - - c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] - c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) - tl.store(c_ptrs, accumulator, mask=c_mask) -except: - print('triton not installed.') - - -def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) - matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), - qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) - return output - - -def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): - with torch.cuda.device(input.device): - output_dim = (qweight.shape[0] * 32) // bits - output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16) - grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) - transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), - qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) - return output - - -class QuantLinearFunction(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): - output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) - ctx.save_for_backward(qweight, scales, qzeros, g_idx) - ctx.bits, ctx.maxq = bits, maxq - return output - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - qweight, scales, qzeros, g_idx = ctx.saved_tensors - bits, maxq = ctx.bits, ctx.maxq - grad_input = None - - if ctx.needs_input_grad[0]: - grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) - return grad_input, None, None, None, None, None, None - - -class QuantLinear(nn.Module): - - def __init__(self, bits, groupsize, infeatures, outfeatures, bias): - super().__init__() - if bits not in [2, 4, 8]: - raise NotImplementedError("Only 2,4,8 bits are supported.") - self.infeatures = infeatures - self.outfeatures = outfeatures - self.bits = bits - self.maxq = 2**self.bits - 1 - self.groupsize = groupsize if groupsize != -1 else infeatures - - self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) - self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) - self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) - self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) - if bias: - self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) - else: - self.bias = None - - def pack(self, linear, scales, zeros, g_idx=None): - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().half() - if linear.bias is not None: - self.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(self.infeatures): - intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None]) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) - i = 0 - row = 0 - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros -= 1 - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - - def forward(self, x): - out_shape = x.shape[:-1] + (self.outfeatures, ) - out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) - -def make_quant_linear(module, names, bits, groupsize, name=''): - if isinstance(module, QuantLinear): - return - for attr in dir(module): - tmp = getattr(module, attr) - name1 = name + '.' + attr if name != '' else attr - if name1 in names: - delattr(module, attr) - setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) - for name1, child in module.named_children(): - make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) - - -def autotune_warmup_linear(model, transpose=False): - """ - Pre-tunes the quantized kernel - """ - from tqdm import tqdm - - kn_values = {} - - for _, m in model.named_modules(): - if not isinstance(m, QuantLinear): - continue - - k = m.infeatures - n = m.outfeatures - - if (k, n) not in kn_values: - kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq) - - print(f'Found {len(kn_values)} unique KN Linear values.') - - print('Warming up autotune cache ...') - with torch.no_grad(): - for m in tqdm(range(0, 12)): - m = 2**m # [1, 2048] - for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items(): - a = torch.randn(m, k, dtype=torch.float16, device='cuda') - matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) - if transpose: - a = torch.randn(m, n, dtype=torch.float16, device='cuda') - transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) - del kn_values diff --git a/colossalai/gptq/gptq_utils/quant/quantizer.py b/colossalai/gptq/gptq_utils/quant/quantizer.py deleted file mode 100644 index 76844b8769aa..000000000000 --- a/colossalai/gptq/gptq_utils/quant/quantizer.py +++ /dev/null @@ -1,127 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import math - - -class Quantizer(nn.Module): - - def __init__(self, shape=1): - super(Quantizer, self).__init__() - self.register_buffer('maxq', torch.tensor(0)) - self.register_buffer('scale', torch.zeros(shape)) - self.register_buffer('zero', torch.zeros(shape)) - - def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False): - - self.maxq = torch.tensor(2**bits - 1) - self.perchannel = perchannel - self.sym = sym - self.mse = mse - self.norm = norm - self.grid = grid - self.maxshrink = maxshrink - if trits: - self.maxq = torch.tensor(-1) - self.scale = torch.zeros_like(self.scale) - - def _quantize(self, x, scale, zero, maxq): - if maxq < 0: - return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero - q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) - return scale * (q - zero) - - def find_params(self, x, weight=False): - dev = x.device - self.maxq = self.maxq.to(dev) - - shape = x.shape - if self.perchannel: - if weight: - x = x.flatten(1) - else: - if len(shape) == 4: - x = x.permute([1, 0, 2, 3]) - x = x.flatten(1) - if len(shape) == 3: - x = x.reshape((-1, shape[-1])).t() - if len(shape) == 2: - x = x.t() - else: - x = x.flatten().unsqueeze(0) - - tmp = torch.zeros(x.shape[0], device=dev) - xmin = torch.minimum(x.min(1)[0], tmp) - xmax = torch.maximum(x.max(1)[0], tmp) - - if self.sym: - xmax = torch.maximum(torch.abs(xmin), xmax) - tmp = xmin < 0 - if torch.any(tmp): - xmin[tmp] = -xmax[tmp] - tmp = (xmin == 0) & (xmax == 0) - xmin[tmp] = -1 - xmax[tmp] = +1 - - if self.maxq < 0: - self.scale = xmax - self.zero = xmin - else: - self.scale = (xmax - xmin) / self.maxq - if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) - else: - self.zero = torch.round(-xmin / self.scale) - - if self.mse: - best = torch.full([x.shape[0]], float('inf'), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) - q -= x - q.abs_() - q.pow_(self.norm) - err = torch.sum(q, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - if not self.perchannel: - if weight: - tmp = shape[0] - else: - tmp = shape[1] if len(shape) != 3 else shape[2] - self.scale = self.scale.repeat(tmp) - self.zero = self.zero.repeat(tmp) - - if weight: - shape = [-1] + [1] * (len(shape) - 1) - self.scale = self.scale.reshape(shape) - self.zero = self.zero.reshape(shape) - return - if len(shape) == 4: - self.scale = self.scale.reshape((1, -1, 1, 1)) - self.zero = self.zero.reshape((1, -1, 1, 1)) - if len(shape) == 3: - self.scale = self.scale.reshape((1, 1, -1)) - self.zero = self.zero.reshape((1, 1, -1)) - if len(shape) == 2: - self.scale = self.scale.unsqueeze(0) - self.zero = self.zero.unsqueeze(0) - - def quantize(self, x): - if self.ready(): - return self._quantize(x, self.scale, self.zero, self.maxq) - - return x - - def enabled(self): - return self.maxq > 0 - - def ready(self): - return torch.all(self.scale != 0) diff --git a/colossalai/gptq/gptq_utils/quant/triton_norm.py b/colossalai/gptq/gptq_utils/quant/triton_norm.py deleted file mode 100644 index 1e3228a18d51..000000000000 --- a/colossalai/gptq/gptq_utils/quant/triton_norm.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -from torch import nn -import triton -import triton.language as tl -from transformers.models.llama.modeling_llama import LlamaRMSNorm - -@triton.jit -def rms_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - Y += row * stride - X += row * stride - # Compute variance - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - x = tl.where(cols < N, x, 0.) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) - x_hat = x * rstd - y = x_hat * w - # Write output - tl.store(Y + cols, y, mask=mask) - -class TritonLlamaRMSNorm(nn.Module): - def __init__(self, weight, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = weight - self.variance_epsilon = eps - - def forward(self, x): - with torch.cuda.device(x.device): - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.reshape(-1, x.shape[-1]) - M, N = x_arg.shape - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - if N > BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - # enqueue kernel - rms_norm_fwd_fused[(M,)](x_arg, y, self.weight, - x_arg.stride(0), N, self.variance_epsilon, - BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) - return y - - -def make_quant_norm(model): - """ - Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules - """ - - for name, m in model.named_modules(): - if not isinstance(m, LlamaRMSNorm): - continue - - norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon) - - if '.' in name: - parent_name = name.rsplit('.', 1)[0] - child_name = name[len(parent_name) + 1:] - parent = model.get_submodule(parent_name) - else: - parent_name = '' - parent = model - child_name = name - - #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") - - setattr(parent, child_name, norm) diff --git a/colossalai/gptq/gptq_utils/utils/__init__.py b/colossalai/gptq/gptq_utils/utils/__init__.py deleted file mode 100644 index cf1741216f79..000000000000 --- a/colossalai/gptq/gptq_utils/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .modelutils import DEV, find_layers, gen_conditions, torch_snr_error -from .datautils import set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders -from .export import export_quant_table diff --git a/colossalai/gptq/gptq_utils/utils/datautils.py b/colossalai/gptq/gptq_utils/utils/datautils.py deleted file mode 100644 index 10a3a43d3ef5..000000000000 --- a/colossalai/gptq/gptq_utils/utils/datautils.py +++ /dev/null @@ -1,193 +0,0 @@ -import numpy as np -import torch - - -def set_seed(seed): - np.random.seed(seed) - torch.random.manual_seed(seed) - - -def get_wikitext2(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') - testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') - - from transformers import AutoTokenizer - try: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - except: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) - trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') - testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc - - -def get_ptb(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') - valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation') - - from transformers import AutoTokenizer - try: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - except: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) - trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') - testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc - - -def get_c4(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False) - valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False) - - from transformers import AutoTokenizer - try: - if "llama" in model: - from transformers import LlamaTokenizer - tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False) - else: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - except: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - while True: - i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') - if trainenc.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - - import random - random.seed(0) - valenc = [] - for _ in range(256): - while True: - i = random.randint(0, len(valdata) - 1) - tmp = tokenizer(valdata[i]['text'], return_tensors='pt') - if tmp.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - valenc.append(tmp.input_ids[:, i:j]) - valenc = torch.hstack(valenc) - - class TokenizerWrapper: - - def __init__(self, input_ids): - self.input_ids = input_ids - - valenc = TokenizerWrapper(valenc) - - return trainloader, valenc - - -def get_ptb_new(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') - testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') - - from transformers import AutoTokenizer - try: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - except: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) - trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') - testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - return trainloader, testenc - - -def get_c4_new(nsamples, seed, seqlen, model): - from datasets import load_dataset - traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') - valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') - - from transformers import AutoTokenizer - try: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) - except: - tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) - - import random - random.seed(seed) - trainloader = [] - for _ in range(nsamples): - while True: - i = random.randint(0, len(traindata) - 1) - trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') - if trainenc.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - tar = inp.clone() - tar[:, :-1] = -100 - trainloader.append((inp, tar)) - - valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') - valenc = valenc.input_ids[:, :(256 * seqlen)] - - class TokenizerWrapper: - - def __init__(self, input_ids): - self.input_ids = input_ids - - valenc = TokenizerWrapper(valenc) - - return trainloader, valenc - - -def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): - if 'wikitext2' in name: - return get_wikitext2(nsamples, seed, seqlen, model) - if 'ptb' in name: - if 'new' in name: - return get_ptb_new(nsamples, seed, seqlen, model) - return get_ptb(nsamples, seed, seqlen, model) - if 'c4' in name: - if 'new' in name: - return get_c4_new(nsamples, seed, seqlen, model) - return get_c4(nsamples, seed, seqlen, model) diff --git a/colossalai/gptq/gptq_utils/utils/export.py b/colossalai/gptq/gptq_utils/utils/export.py deleted file mode 100644 index a623afcf49b5..000000000000 --- a/colossalai/gptq/gptq_utils/utils/export.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np -import toml -import os - - -def export_quant_table(quantizers: dict, quant_dir: str, format: str = 'toml'): - - table = {} - - def save_tensor(name: str, tensor): - np.save(os.path.join(quant_dir, name), tensor.numpy()) - return '{}.npy'.format(name) - - for key, value in quantizers.items(): - quantizer = value[0] - - dump = dict() - - sym = quantizer.sym - if not sym: - dump['zero'] = save_tensor(name=key + '.zero', tensor=value[2]) - dump['scale'] = save_tensor(name=key + '.scale', tensor=value[1]) - dump['wbits'] = value[4] - dump['groupsize'] = value[5] - if value[5] > 0: - dump['group_ids'] = save_tensor(name=key + '.group_ids', tensor=value[3]) - - dump['sym'] = sym - dump['perchannel'] = quantizer.perchannel - - table[key] = dump - - if not os.path.exists(quant_dir): - os.mkdir(quant_dir) - - with open(os.path.join(quant_dir, 'quant.toml'), 'w') as f: - toml.dump(table, f) diff --git a/colossalai/gptq/gptq_utils/utils/modelutils.py b/colossalai/gptq/gptq_utils/utils/modelutils.py deleted file mode 100644 index d043cca02b7d..000000000000 --- a/colossalai/gptq/gptq_utils/utils/modelutils.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.nn as nn - -DEV = torch.device('cuda:0') - - -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): - if type(module) in layers: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) - return res - - -def gen_conditions(_wbits, _groupsize): - wbits = _wbits - groupsize = _groupsize - conditions = [] - while True: - if wbits >= 8: - if groupsize == -1 or groupsize == 32: - break - - if groupsize > 32: - groupsize /= 2 - else: - wbits *= 2 - groupsize = _groupsize - - conditions.append((int(wbits), int(groupsize))) - return conditions - - -# copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py -def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: - """ - Compute SNR between y_pred(tensor) and y_real(tensor) - - SNR can be calcualted as following equation: - - SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 - - if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. - - SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) - Args: - y_pred (torch.Tensor): _description_ - y_real (torch.Tensor): _description_ - reduction (str, optional): _description_. Defaults to 'mean'. - Raises: - ValueError: _description_ - ValueError: _description_ - Returns: - torch.Tensor: _description_ - """ - y_pred = y_pred.type(torch.float32) - y_real = y_real.type(torch.float32) - - if y_pred.shape != y_real.shape: - raise ValueError(f'Can not compute snr loss for tensors with different shape. ' - f'({y_pred.shape} and {y_real.shape})') - reduction = str(reduction).lower() - - if y_pred.ndim == 1: - y_pred = y_pred.unsqueeze(0) - y_real = y_real.unsqueeze(0) - - y_pred = y_pred.flatten(start_dim=1) - y_real = y_real.flatten(start_dim=1) - - noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) - signal_power = torch.pow(y_real, 2).sum(dim=-1) - snr = (noise_power) / (signal_power + 1e-7) - - if reduction == 'mean': - return torch.mean(snr) - elif reduction == 'sum': - return torch.sum(snr) - elif reduction == 'none': - return snr - else: - raise ValueError(f'Unsupported reduction method.') diff --git a/tests/test_gptq/linear_act_fusion_bench.py b/examples/ops/gptq/linear_act_fusion_bench.py similarity index 88% rename from tests/test_gptq/linear_act_fusion_bench.py rename to examples/ops/gptq/linear_act_fusion_bench.py index d518243b0059..c4ecf5cb5e18 100644 --- a/tests/test_gptq/linear_act_fusion_bench.py +++ b/examples/ops/gptq/linear_act_fusion_bench.py @@ -1,19 +1,20 @@ - import torch import torch.nn as nn - +import pytest import time from argparse import ArgumentParser import transformers -from colossalai.gptq.gptq_utils import GPTQ -from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders -from colossalai.gptq.gptq_utils import quant -from colossalai.gptq.gptq_utils.quant import Quantizer +from auto_gptq.quantization import GPTQ +from auto_gptq.modeling._utils import find_layers, pack_model +# from auto_gptq import quant +from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear + +from auto_gptq.quantization.quantizer import Quantizer from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp import math import numpy as np -import csv +# import csv class MLinear(nn.Module): def __init__(self, infeature, outfeature): @@ -28,6 +29,7 @@ def model_quant(model, inps, dev, args): print('Starting ...') layers = [model] layers[0] = layers[0].to(dev) + dtype = next(iter(model.parameters())).dtype cache = {'i': 0} class Catcher(nn.Module): @@ -46,7 +48,7 @@ def forward(self, inp, **kwargs): pass layers[0] = layers[0].module - layers[0] = layers[0].cpu() + # layers[0] = layers[0].cpu() # outs = torch.zeros_like(inps) outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) @@ -56,13 +58,12 @@ def forward(self, inp, **kwargs): quantizers = {} for i in range(len(layers)): layer = layers[i].to(dev) - subset = find_layers(layer) gptq = {} for name in subset: gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure( args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits ) + # gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits) def add_batch(name): def tmp(_, inp, out): @@ -80,10 +81,13 @@ def tmp(_, inp, out): h.remove() for name in subset: print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') - scale,zero,g_idx,error= gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + scale,zero,g_idx = gptq[name].fasterquant(percdamp=args.percdamp, group_size=args.groupsize, actorder=args.act_order) + # quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + quantizers['%s' % (name)] = (gptq[name].layer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + gptq[name].free() for j in range(args.nsamples): + layer = layer.to(dev) outs[j] = layer(inps[j].unsqueeze(0))[0] layers[i] = layer.cpu() @@ -95,17 +99,10 @@ def tmp(_, inp, out): return quantizers + def model_pack(model, quantizers, wbits, groupsize): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - quant.make_quant_linear(model, quantizers, wbits, groupsize) - qlayers = find_layers(model, [quant.QuantLinear]) - print('Packing ...') - for name in qlayers: - quantizers[name], scale, zero, g_idx = quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') - return qlayers['linear'] + pack_model(model, quantizers, wbits, groupsize) + return model @@ -231,7 +228,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) - batch_inps = torch.randn(1, 16384, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) gptq_linear_time = 0 torch_linear_time = 0 @@ -246,7 +243,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize # torch_out = inps # print(f"torch out {torch_out}") torch_out = linear(torch_out) - torch.cuda.synchronize() + torch.cuda.synchronize() time_start = time.time() for i in range(0, benchmark_iter): @@ -254,7 +251,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize torch_out = act_func(inps) # torch_out = inps torch_out = linear(torch_out) - torch.cuda.synchronize() + torch.cuda.synchronize() time_end = time.time() torch_linear_time = time_end - time_start @@ -266,7 +263,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize torch_out = act_func(batch_inps) # torch_out = inps torch_out = linear(torch_out) - torch.cuda.synchronize() + torch.cuda.synchronize() time_end = time.time() torch_batch_linear_time = time_end - time_start @@ -283,7 +280,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize gptq_out = act_func(inps) # gptq_out = inps gptq_out = gptq_model(gptq_out) - torch.cuda.synchronize() + torch.cuda.synchronize() time_start = time.time() for i in range(0, benchmark_iter): @@ -291,7 +288,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize gptq_out = act_func(inps) # gptq_out = inps gptq_out = gptq_model(gptq_out) - torch.cuda.synchronize() + torch.cuda.synchronize() time_end = time.time() @@ -310,7 +307,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize gptq_out = act_func(batch_inps) # gptq_out = inps gptq_out = gptq_model(gptq_out) - torch.cuda.synchronize() + torch.cuda.synchronize() time_end = time.time() @@ -334,7 +331,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize act_type=0, bias = bias, qkv_fused = qkv_fused) - torch.cuda.synchronize() + torch.cuda.synchronize() print("warm up cai linear") @@ -342,6 +339,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize # f = open('cai_time.csv', 'w') # writer = csv.writer(f) + time_start = time.time() for i in range(0, warm_up_iter): with torch.no_grad(): @@ -352,7 +350,8 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize act_type=0, bias = bias, qkv_fused = qkv_fused) - torch.cuda.synchronize() + torch.cuda.synchronize() + time_end = time.time() cai_linear_time = time_end - time_start # print("block dim x:{}, block dim y:{}, time: {:.8f} ".format(i, j, cai_linear_time/benchmark_iter)) @@ -369,7 +368,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize act_type=0, bias = bias, qkv_fused = qkv_fused) - torch.cuda.synchronize() + torch.cuda.synchronize() time_end = time.time() batch_cai_linear_time = time_end - time_start diff --git a/requirements/requirements.txt b/requirements/requirements.txt index eece233e4e48..342fdb8f1124 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -13,3 +13,4 @@ safetensors flash_attn>=2.0 einops texttable +auto-gptq diff --git a/tests/test_gptq/quant_llama.py b/tests/test_gptq/quant_llama.py deleted file mode 100644 index cc6d980019db..000000000000 --- a/tests/test_gptq/quant_llama.py +++ /dev/null @@ -1,570 +0,0 @@ -import argparse -import time -import numpy as np -import torch -import torch.nn as nn -from colossalai.gptq.gptq_utils import quant -from colossalai.gptq import cai_gptq - -from colossalai.gptq.gptq_utils import GPTQ, Observer -from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions -from texttable import Texttable -from colossalai.gptq import CaiInferenceConfig -from transformers import LlamaForCausalLM, LlamaTokenizer - -import csv - -def get_llama(model): - - def skip(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - from transformers import LlamaForCausalLM, LlamaConfig, LlamaModel - if args.debug: - llama_kwargs= {"bos_token_id": 0, - "eos_token_id": 1, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 2048, - "max_sequence_length": 2048, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 1, - "pad_token_id": -1, - "rms_norm_eps": 1e-06, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "use_cache": True, - "vocab_size": 32000 - } - configuration = LlamaConfig( **llama_kwargs - ) - model = LlamaForCausalLM(configuration) - else: - model = LlamaForCausalLM.from_pretrained(model, torch_dtype=torch.float16) - - # # LlamaForCausalLM - model.seqlen = 2048 - return model - - -@torch.no_grad() -def llama_sequential(model, dataloader, dev): - print('Starting ...') - - use_cache = model.config.use_cache - model.config.use_cache = False - layers = model.model.layers - - model.model.embed_tokens = model.model.embed_tokens.to(dev) - model.model.norm = model.model.norm.to(dev) - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - inps = torch.zeros((args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev) - cache = {'i': 0, 'attention_mask': None} - - class Catcher(nn.Module): - - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] - cache['position_ids'] = kwargs['position_ids'] - raise ValueError - - layers[0] = Catcher(layers[0]) - for batch in dataloader: - try: - model(batch[0].to(dev)) - except ValueError: - pass - layers[0] = layers[0].module - - layers[0] = layers[0].cpu() - model.model.embed_tokens = model.model.embed_tokens.cpu() - model.model.norm = model.model.norm.cpu() - torch.cuda.empty_cache() - - outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] - position_ids = cache['position_ids'] - - print('Ready.') - - quantizers = {} - observer = Observer() - for i in range(len(layers)): - - print(f'Quantizing layer {i+1}/{len(layers)}..') - print('+------------------+--------------+------------+-----------+-------+') - print('| name | weight_error | fp_inp_SNR | q_inp_SNR | time |') - print('+==================+==============+============+===========+=======+') - - layer = layers[i].to(dev) - full = find_layers(layer) - if args.true_sequential: - sequential = [['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'], ['self_attn.o_proj'], ['mlp.up_proj', 'mlp.gate_proj'], ['mlp.down_proj']] - else: - sequential = [list(full.keys())] - - for names in sequential: - subset = {n: full[n] for n in names} - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name], observe=args.observe) - gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False) - - def add_batch(name): - - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] - for h in handles: - h.remove() - - for name in subset: - scale, zero, g_idx, error = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, name=name) - quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), args.wbits, args.groupsize) - - if args.observe: - observer.submit(name=name, layerid=i, gptq=gptq[name], error=error) - else: - gptq[name].free() - - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0] - - layers[i] = layer.cpu() - del layer - del gptq - torch.cuda.empty_cache() - - inps, outs = outs, inps - print('+------------------+--------------+------------+-----------+-------+') - print('\n') - - if args.observe: - observer.print() - conditions = gen_conditions(args.wbits, args.groupsize) - for item in observer.items(): - name = item[0] - layerid = item[1] - gptq = item[2]['gptq'] - error = item[2]['error'] - target = error / 2 - - table = Texttable() - table.header(['wbits', 'groupsize', 'error']) - table.set_cols_dtype(['i', 'i', 'f']) - table.add_row([args.wbits, args.groupsize, error]) - - print('Optimizing {} {} ..'.format(name, layerid)) - for wbits, groupsize in conditions: - - if error < target: - # if error dropped 50%, skip - break - - gptq.quantizer.configure(wbits, perchannel=True, sym=args.sym, mse=False) - - scale, zero, g_idx, error = gptq.fasterquant(percdamp=args.percdamp, groupsize=groupsize, actorder=args.act_order, name=name) - - table.add_row([wbits, groupsize, error]) - quantizers['model.layers.%d.%s' % (layerid, name)] = (gptq.quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu(), wbits, groupsize) - - print(table.draw()) - print('\n') - gptq.layer.to('cpu') - gptq.free() - - model.config.use_cache = use_cache - - return quantizers - - -# TODO: perform packing on GPU -def cai_llama_pack(model, quantizers, wbits, groupsize): - layers = find_layers(model) - # print(f"model {model}") - # print(f"layers {layers}") - - layers = {n: layers[n] for n in quantizers} - # print(f"quantizers {quantizers}") - cai_gptq.make_cai_quant_linear(model, quantizers, wbits, groupsize) - qlayers = find_layers(model, [cai_gptq.CaiQuantLinear]) - print('Packing ...') - for name in qlayers: - print(name) - quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') - return model - -def gptq_llama_pack(model, quantizers, wbits, groupsize): - layers = find_layers(model) - # print(f"model {model}") - # print(f"layers {layers}") - - layers = {n: layers[n] for n in quantizers} - # print(f"quantizers {quantizers}") - quant.make_quant_linear(model, quantizers, wbits, groupsize) - qlayers = find_layers(model, [quant.QuantLinear]) - print('Packing ...') - for name in qlayers: - print(name) - quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') - return model - - -def cai_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): - from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils - config = LlamaConfig.from_pretrained(model) - - def noop(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = LlamaForCausalLM(config) - torch.set_default_dtype(torch.float) - if eval: - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - cai_gptq.make_cai_quant_linear(model, layers, wbits, groupsize) - - del layers - - print('Loading model ...') - if checkpoint.endswith('.safetensors'): - from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint)) - else: - model.load_state_dict(torch.load(checkpoint)) - - print('Done.') - - return model - - -def gptq_load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): - from transformers import LlamaConfig, LlamaForCausalLM, modeling_utils - config = LlamaConfig.from_pretrained(model) - - def noop(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop - - torch.set_default_dtype(torch.half) - modeling_utils._init_weights = False - torch.set_default_dtype(torch.half) - model = LlamaForCausalLM(config) - torch.set_default_dtype(torch.float) - if eval: - model = model.eval() - layers = find_layers(model) - for name in ['lm_head']: - if name in layers: - del layers[name] - quant.make_quant_linear(model, layers, wbits, groupsize) - - del layers - - print('Loading model ...') - if checkpoint.endswith('.safetensors'): - from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint)) - else: - model.load_state_dict(torch.load(checkpoint)) - - print('Done.') - - return model - -all_perfs = [] -now_perf=[] - -def print_perf_stats(latency_set, config, warmup=3): - global now_perf - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 - # if args.dtype == "float16": - # num_bytes = 2 - # elif args.dtype == "float32": - # num_bytes = 4 - # else: - # num_bytes = 1 - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1/avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1/avg * num_parameters * num_bytes * args.batch_size / 1e12)) - print("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) - print("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) - row = [args.batch_size, args.input_len, args.max_new_tokens, "{0:8.2f}".format(avg * 1000), - "{0:8.2f}".format(torch.cuda.memory_allocated() / 1e9), - "{0:8.2f}".format(torch.cuda.max_memory_allocated()/1e9)] - with open('./{}_profile.csv'.format(args.model_type), 'a', encoding='UTF8') as f: - # create the csv writer - writer = csv.writer(f) - - # write a row to the csv file - writer.writerow(row) - - now_perf.append("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - now_perf.append("Alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.memory_allocated() / 1e9)) - now_perf.append("Max alloc GPU Mem: {0:8.2f} GB".format(torch.cuda.max_memory_allocated()/1e9)) - - all_perfs.append(now_perf) - now_perf = [] - -def benchmark(model): - - input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), - "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} - torch.cuda.synchronize() - iters = 10 if args.benchmark else 2 #warmup - print(f"model config {model.config}") - - times = [] - warmup=3 - prof_flag = 0 - generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) - torch.cuda.reset_peak_memory_stats() - for i in range(iters): - if i >= warmup: - prof_flag=1 - torch.cuda.synchronize() - start = time.time() - outputs = model.generate(**input_tokens, - **generate_kwargs) - torch.cuda.synchronize() - end = time.time() - times.append(end - start) - print("outpus shape: ", outputs.shape) - print(args) - print("input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens) - # if args.local_rank == 0: - now_perf.extend(["input batch, input len, out len: ",args.batch_size, args.input_len, args.max_new_tokens]) - print_perf_stats(map(lambda t: t / args.max_new_tokens, times), model.config) - -def test(model_1, model_2): - # input_tokens = {"input_ids":torch.randint(1, 1000, (args.batch_size, args.input_len), device=DEV), - # "attention_mask":torch.ones((args.batch_size, args.input_len), device=DEV)} - generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False) - - - tokenizer = LlamaTokenizer.from_pretrained(args.model) - tokenizer.pad_token_id = tokenizer.unk_token_id - - text = "how is weather today? I want to know the weather of beijing. " - text = "how are you?" - - inputs = [text] - input_tokens = tokenizer.batch_encode_plus(inputs, padding = True, return_tensors="pt") - - input_len = 0 - for t in input_tokens: - if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) - # print(input_tokens[t].shape) - input_len = input_tokens[t].shape[1] - - outputs_1 = model_1.generate(**input_tokens, - **generate_kwargs) - print("model 1 done") - out_1 = tokenizer.batch_decode(outputs_1) - - print("decode out:", out_1) - if model_2 is None: - return - outputs_2 = model_2.generate(**input_tokens, - **generate_kwargs) - print("model 2 done") - - out_2 = tokenizer.batch_decode(outputs_2) - - ret = torch.allclose(outputs_1, outputs_2) - print("allclose is ", ret) - - print("decode out:", out_2) - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - - parser.add_argument('model', type=str, help='llama model to load') - parser.add_argument('dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], help='Where to extract calibration data from.') - parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.') - parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration data samples.') - parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') - parser.add_argument('--nearest', action='store_true', help='Whether to run the RTN baseline.') - parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') - parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') - parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') - parser.add_argument('--save', type=str, default='', help='Save quantized checkpoint under this name.') - parser.add_argument('--save_safetensors', type=str, default='', help='Save quantized `.safetensors` checkpoint under this name.') - parser.add_argument('--load', type=str, default='', help='Load quantized model.') - parser.add_argument('--benchmark', action='store_true', help='Number of tokens to use for benchmarking.') - parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.') - parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') - parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') - parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential model.') - parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.') - parser.add_argument('--observe', - action='store_true', - help='Auto upgrade layer precision to higher precision, for example int2 to int4, groupsize 128 to 64. \ - When this feature enabled, `--save` or `--save_safetensors` would be disable.') - parser.add_argument('--quant-directory', type=str, default=None, help='Specify the directory for export quantization parameters to toml format. `None` means no export by default.') - parser.add_argument('--max_new_tokens', type=int, default=32, help='Max new tokens to generate.') - parser.add_argument('--batch_size', type=int, default=1, help='Batch size to generate.') - parser.add_argument('--input_len', type=int, default=128, help='Batch size to generate.') - parser.add_argument('--model_type', type=str, choices=['cai', 'gptq', 'torch'], default='torch', help='Batch size to generate.') - parser.add_argument('--debug', action='store_true', help='Whether to debug or not') - - args = parser.parse_args() - - model_packed = False - if type(args.load) is not str: - args.load = args.load.as_posix() - - if args.load: - if args.model_type == "gptq": - model = gptq_load_quant(args.model, args.load, args.wbits, args.groupsize) - elif args.model_type == "cai": - model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) - else: - model = get_llama(args.model) - model.half() - - if not args.load and args.wbits < 16 and not args.nearest and args.model_type in ['cai', 'gptq']: - dataloader, testloader = get_loaders(args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen) - tick = time.time() - quantizers = llama_sequential(model, dataloader, DEV) - if args.model_type == "cai": - cai_llama_pack(model, quantizers, args.wbits, args.groupsize) - elif args.model_type == "gptq": - gptq_llama_pack(model, quantizers, args.wbits, args.groupsize) - model_packed = True - print(time.time() - tick) - - - if args.quant_directory is not None: - export_quant_table(quantizers, args.quant_directory) - - if not args.observe and args.save and args.model_type in ['cai', 'gptq']: - if not model_packed: - llama_pack(model, quantizers, args.wbits, args.groupsize) - model_packed = True - torch.save(model.state_dict(), args.save) - - if not args.observe and args.save_safetensors and args.model_type in ['cai', 'gptq']: - if not model_packed: - llama_pack(model, quantizers, args.wbits, args.groupsize) - from safetensors.torch import save_file as safe_save - state_dict = model.state_dict() - state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} - safe_save(state_dict, args.save_safetensors) - - if args.benchmark: - model = model.to(DEV) - print(f"model config {model.config.num_key_value_heads}") - - if args.model_type == "cai": - cai_inf_config = CaiInferenceConfig(fp16=True, - device=torch.cuda.current_device(), - gptq=True, - gptq_group_size=128, - gptq_quant_bits=4) - model = convert_to_ds_model(model, cai_inf_config) - model.cuda().to(torch.cuda.current_device()) - benchmark(model) - - - # torch_model = get_llama(args.model) - # torch_model.half() - # torch_model = torch_model.to(DEV) - - # gptq_model = gptq_load_quant(args.model, "llama7b-4bit-128g-gptq-nao.pt", args.wbits, args.groupsize) - # gptq_model = gptq_model.to(DEV) - - # model = cai_load_quant(args.model, args.load, args.wbits, args.groupsize) - # model = model.to(DEV) - - - # test(torch_model, model) - # test(gptq_model, None) - - # print("torch_model ", torch_model) - # print("gptq_model ", gptq_model) - # print("cai_model ", model) - # torch_qkv_out = torch_model.model.layers[0].self_attn.qkv_out - # cai_qkv_out = model.model.layers[0].self_attn.qkv_out - # gptq_qkv_out = gptq_model.model.layers[0].self_attn.qkv_out - - # gptq_out = gptq_model.model.layers[0].self_attn.q_proj.scales - # cai_out = model.model.layers[0].self_attn.q_proj.scales - - # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) - # max_diff = torch.max(torch.abs(cai_out - gptq_out)) - # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # for i in range(3): - # cai_out = cai_qkv_out[i] - # torch_out = torch_qkv_out[i] - # gptq_out = gptq_qkv_out[i] - # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) - # max_diff = torch.max(torch.abs(cai_out - gptq_out)) - # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) - # max_diff = torch.max(torch.abs(torch_out - gptq_out)) - # print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(torch_out - cai_out)) - # max_diff = torch.max(torch.abs(torch_out - cai_out)) - # print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - - # # for batch in [1, 2, 4, 8, 16, 32]: - # for batch in [1]: - # args.batch_size = batch - # # for in_len in [128, 256, 512, 1024, 2048]: - # for in_len in [1024]: - # args.input_len = in_len - # benchmark(model) - # # for info in all_perfs: - # # print(info) - # # # all_perfs = [] \ No newline at end of file diff --git a/tests/test_gptq/run_gptq.sh b/tests/test_gptq/run_gptq.sh deleted file mode 100644 index cc625478de71..000000000000 --- a/tests/test_gptq/run_gptq.sh +++ /dev/null @@ -1,19 +0,0 @@ -# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ -# --wbits 4 --true-sequential --groupsize 128 --save ./llama7b-4bit-128g-cai-nao.pt\ -# --benchmark --model_type cai --input_len 1024 --max_new_tokens 128 --batch_size 1 - -# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ -# --wbits 4 --true-sequential --groupsize 128 --save ./llama7b-4bit-128g-gptq-nao.pt\ -# --benchmark --model_type gptq --input_len 1024 --max_new_tokens 128 --batch_size 1 - -# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ -# --wbits 4 --true-sequential --act-order --groupsize 128 --load ./llama7b-4bit-128g-cai-nao.pt\ -# --benchmark --model_type cai --input_len 1024 --max_new_tokens 128 --batch_size 1 - -# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=5 python quant_llama.py /data/scratch/llama-7b-hf c4 \ -# --wbits 4 --true-sequential --act-order --groupsize 128 --load /llama7b-4bit-128g-gptq-nao.pt \ -# --benchmark --model_type gptq --input_len 1024 --max_new_tokens 128 --batch_size 1 - -# OMP_NUM_THREADS=48 CUDA_VISIBLE_DEVICES=4 python quant_llama.py /data/scratch/llama-13b-hf c4 \ -# --wbits 4 --true-sequential --act-order --groupsize 128 \ -# --benchmark --model_type torch --input_len 1024 --max_new_tokens 128 --batch_size 1 diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py index b81388a41b43..5780f3941231 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -1,19 +1,21 @@ import torch import torch.nn as nn - +import pytest import time from argparse import ArgumentParser import transformers -from colossalai.gptq.gptq_utils import GPTQ -from colossalai.gptq.gptq_utils.utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders -from colossalai.gptq.gptq_utils import quant -from colossalai.gptq.gptq_utils.quant import Quantizer +from auto_gptq.quantization import GPTQ +from auto_gptq.modeling._utils import find_layers, pack_model +# from auto_gptq import quant +from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear + +from auto_gptq.quantization.quantizer import Quantizer from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp import math import numpy as np -import csv +# import csv class MLinear(nn.Module): def __init__(self, infeature, outfeature): @@ -28,6 +30,7 @@ def model_quant(model, inps, dev, args): print('Starting ...') layers = [model] layers[0] = layers[0].to(dev) + dtype = next(iter(model.parameters())).dtype cache = {'i': 0} class Catcher(nn.Module): @@ -46,7 +49,7 @@ def forward(self, inp, **kwargs): pass layers[0] = layers[0].module - layers[0] = layers[0].cpu() + # layers[0] = layers[0].cpu() # outs = torch.zeros_like(inps) outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) @@ -56,13 +59,12 @@ def forward(self, inp, **kwargs): quantizers = {} for i in range(len(layers)): layer = layers[i].to(dev) - subset = find_layers(layer) gptq = {} for name in subset: gptq[name] = GPTQ(subset[name]) - gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure( args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits ) + # gptq[name].quantizer = Quantizer() + gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits) def add_batch(name): def tmp(_, inp, out): @@ -80,10 +82,13 @@ def tmp(_, inp, out): h.remove() for name in subset: print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') - scale,zero,g_idx,error= gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + scale,zero,g_idx = gptq[name].fasterquant(percdamp=args.percdamp, group_size=args.groupsize, actorder=args.act_order) + # quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + quantizers['%s' % (name)] = (gptq[name].layer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + gptq[name].free() for j in range(args.nsamples): + layer = layer.to(dev) outs[j] = layer(inps[j].unsqueeze(0))[0] layers[i] = layer.cpu() @@ -95,17 +100,10 @@ def tmp(_, inp, out): return quantizers + def model_pack(model, quantizers, wbits, groupsize): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - quant.make_quant_linear(model, quantizers, wbits, groupsize) - qlayers = find_layers(model, [quant.QuantLinear]) - print('Packing ...') - for name in qlayers: - quantizers[name], scale, zero, g_idx = quantizers[name] - qlayers[name].pack(layers[name], scale, zero, g_idx) - print('Done.') - return qlayers['linear'] + pack_model(model, quantizers, wbits, groupsize) + return model @@ -193,9 +191,9 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize # print("cai pack", layers) return qweight, qscales, qzeros -if __name__ == "__main__": +def test_gptq_linear(): parser = ArgumentParser() parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') @@ -205,6 +203,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize parser.add_argument('--groupsize', type=int, default=128, help='Groupsize to use for quantization; default uses full row.') parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') args = parser.parse_args() + infeature = 5120 outfeature = 5120 @@ -228,6 +227,7 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize linear = MLinear(infeature, outfeature) linear.to(torch.cuda.current_device()) + with torch.no_grad(): linear.linear.weight.data.copy_(weight) linear.linear.bias.data.copy_(bias) @@ -237,9 +237,9 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize batch_torch_out = linear(batch_inps) torch_out = act_func(torch_out) batch_torch_out = act_func(batch_torch_out) - print("batch_torch out ", batch_torch_out) - linear.to("cpu") + + # linear.to("cuda") quantizers = model_quant(linear, inps, torch.cuda.current_device(), args) qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) gptq_model = model_pack(linear, quantizers, args.wbits, args.groupsize) @@ -283,15 +283,16 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize # cai_out = cai_out[1] # batch_cai_out = batch_cai_out[1] - a = torch.sum(qscales, 0) - print("qscales ", a) - print("orch out ", torch_out) - print("gptq out ", gptq_out) - print("cai out ", cai_out) + # a = torch.sum(qscales, 0) + # print("qscales ", a) + # print("orch out ", torch_out) + # print("gptq out ", gptq_out) + # print("cai out ", cai_out) + # # print("batch_torch out ", batch_torch_out) - print("batch_torch out ", batch_torch_out) - print("batch_gptq out ", batch_gptq_out) - print("batch_cai out ", batch_cai_out) + # print("batch_torch out ", batch_torch_out) + # print("batch_gptq out ", batch_gptq_out) + # print("batch_cai out ", batch_cai_out) mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) max_diff = torch.max(torch.abs(cai_out - gptq_out)) @@ -311,4 +312,10 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize print("batch torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) mean_diff = torch.mean(torch.abs(batch_torch_out - batch_cai_out)) max_diff = torch.max(torch.abs(batch_torch_out - batch_cai_out)) - print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) \ No newline at end of file + print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + +if __name__ == "__main__": + + + + test_gptq_linear() \ No newline at end of file From e7b01cc275a4cfd9542077d84532ce4e122ec98d Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 16 Aug 2023 21:23:19 +0800 Subject: [PATCH 05/15] rname inferance/quant --- examples/{ops/gptq => inference/quant}/linear_act_fusion_bench.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{ops/gptq => inference/quant}/linear_act_fusion_bench.py (100%) diff --git a/examples/ops/gptq/linear_act_fusion_bench.py b/examples/inference/quant/linear_act_fusion_bench.py similarity index 100% rename from examples/ops/gptq/linear_act_fusion_bench.py rename to examples/inference/quant/linear_act_fusion_bench.py From f47fe3111ae32420e86b62041524fbfc40cd44e7 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 16 Aug 2023 21:47:05 +0800 Subject: [PATCH 06/15] refactor test --- tests/test_gptq/test_linear_act_fusion.py | 81 ++++++++++------------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py index 5780f3941231..a91dc3c19ecb 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -1,22 +1,25 @@ - import torch import torch.nn as nn import pytest import time -from argparse import ArgumentParser - import transformers from auto_gptq.quantization import GPTQ from auto_gptq.modeling._utils import find_layers, pack_model -# from auto_gptq import quant from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear from auto_gptq.quantization.quantizer import Quantizer from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp import math import numpy as np -# import csv + +wbits=4 +trits=False +nsamples=1 +percdamp=.01 +groupsize=128 +act_order=False +sym=False class MLinear(nn.Module): def __init__(self, infeature, outfeature): super(MLinear, self).__init__() @@ -26,7 +29,7 @@ def forward(self, x): return out @torch.no_grad() -def model_quant(model, inps, dev, args): +def model_quant(model, inps, dev): print('Starting ...') layers = [model] layers[0] = layers[0].to(dev) @@ -64,7 +67,7 @@ def forward(self, inp, **kwargs): for name in subset: gptq[name] = GPTQ(subset[name]) # gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits) + gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False, trits=trits) def add_batch(name): def tmp(_, inp, out): @@ -75,19 +78,19 @@ def tmp(_, inp, out): for name in subset: handles.append(subset[name].register_forward_hook(add_batch(name))) - for j in range(args.nsamples): + for j in range(nsamples): outs[j] = layer(inps[j].unsqueeze(0))[0] for h in handles: h.remove() for name in subset: print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') - scale,zero,g_idx = gptq[name].fasterquant(percdamp=args.percdamp, group_size=args.groupsize, actorder=args.act_order) + scale,zero,g_idx = gptq[name].fasterquant(percdamp=percdamp, group_size=groupsize, actorder=act_order) # quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) quantizers['%s' % (name)] = (gptq[name].layer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) gptq[name].free() - for j in range(args.nsamples): + for j in range(nsamples): layer = layer.to(dev) outs[j] = layer(inps[j].unsqueeze(0))[0] @@ -106,7 +109,6 @@ def model_pack(model, quantizers, wbits, groupsize): return model - def cai_linear_pack(linear, scales, zeros, out_qweight, out_qscales, out_qzeros, qg_idx, infeatures, groupsize, bits): @@ -194,15 +196,6 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize def test_gptq_linear(): - parser = ArgumentParser() - parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') - parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') - parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') - parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration data samples.') - parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') - parser.add_argument('--groupsize', type=int, default=128, help='Groupsize to use for quantization; default uses full row.') - parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') - args = parser.parse_args() infeature = 5120 outfeature = 5120 @@ -216,12 +209,10 @@ def test_gptq_linear(): # ptype = torch.int32 qweight = torch.zeros(infeature//wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() - qscales = torch.zeros(infeature//args.groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() - qzeros = torch.zeros(infeature//args.groupsize, outfeature//wn, dtype=ptype, device=torch.cuda.current_device()).contiguous() + qscales = torch.zeros(infeature//groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() + qzeros = torch.zeros(infeature//groupsize, outfeature//wn, dtype=ptype, device=torch.cuda.current_device()).contiguous() - # print(linear.linear.weight) act_func = nn.SiLU() - inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) @@ -240,14 +231,14 @@ def test_gptq_linear(): # linear.to("cuda") - quantizers = model_quant(linear, inps, torch.cuda.current_device(), args) - qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) - gptq_model = model_pack(linear, quantizers, args.wbits, args.groupsize) + quantizers = model_quant(linear, inps, torch.cuda.current_device()) + qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, wbits, groupsize) + gptq_model = model_pack(linear, quantizers, wbits, groupsize) gptq_model.to(torch.cuda.current_device()) # gptq_model = linear - cai_linear = CaiGPTQLinearOp(args.groupsize, args.wbits) + cai_linear = CaiGPTQLinearOp(groupsize, wbits) # qweight = torch.cat((qweight, qweight, qweight), dim=0).contiguous() @@ -255,7 +246,6 @@ def test_gptq_linear(): # qzeros = torch.cat((qzeros, qzeros, qzeros), dim=0).contiguous() # bias = torch.cat((bias, bias, bias), dim=0).contiguous() qkv_fused=False - # inps[:, :, 256:] = 0 with torch.no_grad(): gptq_out = gptq_model(inps) @@ -296,26 +286,27 @@ def test_gptq_linear(): mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) max_diff = torch.max(torch.abs(cai_out - gptq_out)) - print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) - max_diff = torch.max(torch.abs(torch_out - gptq_out)) - print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - mean_diff = torch.mean(torch.abs(torch_out - cai_out)) - max_diff = torch.max(torch.abs(torch_out - cai_out)) - print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + assert mean_diff < 1 and max_diff < 1 + + # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) + # max_diff = torch.max(torch.abs(torch_out - gptq_out)) + # print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(torch_out - cai_out)) + # max_diff = torch.max(torch.abs(torch_out - cai_out)) + # print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) mean_diff = torch.mean(torch.abs(batch_cai_out - batch_gptq_out)) max_diff = torch.max(torch.abs(batch_cai_out - batch_gptq_out)) - print("batch cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - mean_diff = torch.mean(torch.abs(batch_torch_out - batch_gptq_out)) - max_diff = torch.max(torch.abs(batch_torch_out - batch_gptq_out)) - print("batch torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - mean_diff = torch.mean(torch.abs(batch_torch_out - batch_cai_out)) - max_diff = torch.max(torch.abs(batch_torch_out - batch_cai_out)) - print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + assert mean_diff < 1 and max_diff < 1 + # print("batch cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(batch_torch_out - batch_gptq_out)) + # max_diff = torch.max(torch.abs(batch_torch_out - batch_gptq_out)) + # print("batch torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) + # mean_diff = torch.mean(torch.abs(batch_torch_out - batch_cai_out)) + # max_diff = torch.max(torch.abs(batch_torch_out - batch_cai_out)) + # print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) if __name__ == "__main__": - - test_gptq_linear() \ No newline at end of file From 17a22e79c068fadd9ef00fc865fdf1bcc30380a5 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 16 Aug 2023 22:04:42 +0800 Subject: [PATCH 07/15] add auto-gptq as an option --- colossalai/gptq/__init__.py | 11 +++++++++++ colossalai/gptq/cai_gptq/__init__.py | 16 ++++++++++++++-- .../inference/quant/linear_act_fusion_bench.py | 4 +--- tests/test_gptq/test_linear_act_fusion.py | 5 +---- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/colossalai/gptq/__init__.py b/colossalai/gptq/__init__.py index b28b04f64312..00d3034e7418 100644 --- a/colossalai/gptq/__init__.py +++ b/colossalai/gptq/__init__.py @@ -1,3 +1,14 @@ +HAS_AUTO_GPTQ = False +try: + import auto_gptq + HAS_AUTO_GPTQ = True +except ImportError: + warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') + HAS_AUTO_GPTQ = False + +if HAS_AUTO_GPTQ: + from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear, + CaiQuantLinear, CaiGPTQLinearOp) diff --git a/colossalai/gptq/cai_gptq/__init__.py b/colossalai/gptq/cai_gptq/__init__.py index 2da4309cca0c..11874c29c8b0 100644 --- a/colossalai/gptq/cai_gptq/__init__.py +++ b/colossalai/gptq/cai_gptq/__init__.py @@ -1,2 +1,14 @@ -from .gptq_triton import gptq_fused_linear_triton -from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear + +HAS_AUTO_GPTQ = False + +try: + import auto_gptq + HAS_AUTO_GPTQ = True +except ImportError: + warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') + HAS_AUTO_GPTQ = False + +if HAS_AUTO_GPTQ: + from .gptq_triton import gptq_fused_linear_triton + from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear + from .gptq_op import CaiGPTQLinearOp diff --git a/examples/inference/quant/linear_act_fusion_bench.py b/examples/inference/quant/linear_act_fusion_bench.py index c4ecf5cb5e18..cdbf9aeef68f 100644 --- a/examples/inference/quant/linear_act_fusion_bench.py +++ b/examples/inference/quant/linear_act_fusion_bench.py @@ -7,14 +7,12 @@ import transformers from auto_gptq.quantization import GPTQ from auto_gptq.modeling._utils import find_layers, pack_model -# from auto_gptq import quant from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear from auto_gptq.quantization.quantizer import Quantizer -from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp +from colossalai.gptq import CaiGPTQLinearOp import math import numpy as np -# import csv class MLinear(nn.Module): def __init__(self, infeature, outfeature): diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py index a91dc3c19ecb..5d489acac547 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -8,7 +8,7 @@ from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear from auto_gptq.quantization.quantizer import Quantizer -from colossalai.gptq.cai_gptq.gptq_op import CaiGPTQLinearOp +from colossalai.gptq import CaiGPTQLinearOp import math import numpy as np @@ -52,9 +52,6 @@ def forward(self, inp, **kwargs): pass layers[0] = layers[0].module - # layers[0] = layers[0].cpu() - - # outs = torch.zeros_like(inps) outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) print('Ready.') From 370aac508c202b28919fc21d2544c6ead63d4c2a Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 16 Aug 2023 22:07:36 +0800 Subject: [PATCH 08/15] reset requirements --- requirements/requirements-test.txt | 1 + requirements/requirements.txt | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e65271621ddd..5fee741ed322 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -17,3 +17,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi SentencePiece ninja flash_attn>=2.0 +auto-gptq diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 342fdb8f1124..65eecce2c34f 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -12,5 +12,3 @@ torch>=1.11 safetensors flash_attn>=2.0 einops -texttable -auto-gptq From 6ce0b817f02ac77888ac4a7eafbd4c5747b02034 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 17 Aug 2023 11:03:19 +0800 Subject: [PATCH 09/15] change assert and check auto-gptq --- colossalai/gptq/__init__.py | 9 +-------- colossalai/gptq/cai_gptq/__init__.py | 3 ++- tests/test_gptq/test_linear_act_fusion.py | 12 ++++++------ 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/colossalai/gptq/__init__.py b/colossalai/gptq/__init__.py index 00d3034e7418..0e0ee5152138 100644 --- a/colossalai/gptq/__init__.py +++ b/colossalai/gptq/__init__.py @@ -1,11 +1,4 @@ -HAS_AUTO_GPTQ = False - -try: - import auto_gptq - HAS_AUTO_GPTQ = True -except ImportError: - warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') - HAS_AUTO_GPTQ = False +from .cai_gptq import HAS_AUTO_GPTQ if HAS_AUTO_GPTQ: from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear, diff --git a/colossalai/gptq/cai_gptq/__init__.py b/colossalai/gptq/cai_gptq/__init__.py index 11874c29c8b0..d49a3cc2c3ad 100644 --- a/colossalai/gptq/cai_gptq/__init__.py +++ b/colossalai/gptq/cai_gptq/__init__.py @@ -7,8 +7,9 @@ except ImportError: warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') HAS_AUTO_GPTQ = False - + if HAS_AUTO_GPTQ: + from .gptq_triton import gptq_fused_linear_triton from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear from .gptq_op import CaiGPTQLinearOp diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py index 5d489acac547..6dcae74c4304 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -281,10 +281,11 @@ def test_gptq_linear(): # print("batch_gptq out ", batch_gptq_out) # print("batch_cai out ", batch_cai_out) - mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) - max_diff = torch.max(torch.abs(cai_out - gptq_out)) - assert mean_diff < 1 and max_diff < 1 + assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-02) + assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-02) + # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) + # max_diff = torch.max(torch.abs(cai_out - gptq_out)) # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) # mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) # max_diff = torch.max(torch.abs(torch_out - gptq_out)) @@ -293,9 +294,8 @@ def test_gptq_linear(): # max_diff = torch.max(torch.abs(torch_out - cai_out)) # print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - mean_diff = torch.mean(torch.abs(batch_cai_out - batch_gptq_out)) - max_diff = torch.max(torch.abs(batch_cai_out - batch_gptq_out)) - assert mean_diff < 1 and max_diff < 1 + # mean_diff = torch.mean(torch.abs(batch_cai_out - batch_gptq_out)) + # max_diff = torch.max(torch.abs(batch_cai_out - batch_gptq_out)) # print("batch cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) # mean_diff = torch.mean(torch.abs(batch_torch_out - batch_gptq_out)) # max_diff = torch.max(torch.abs(batch_torch_out - batch_gptq_out)) From cdb31810515118686990ec092a48c812a34494e6 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 17 Aug 2023 11:22:27 +0800 Subject: [PATCH 10/15] add import warnings --- colossalai/gptq/cai_gptq/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/gptq/cai_gptq/__init__.py b/colossalai/gptq/cai_gptq/__init__.py index d49a3cc2c3ad..68addb8fb2f5 100644 --- a/colossalai/gptq/cai_gptq/__init__.py +++ b/colossalai/gptq/cai_gptq/__init__.py @@ -1,15 +1,14 @@ +import warnings HAS_AUTO_GPTQ = False - try: import auto_gptq HAS_AUTO_GPTQ = True except ImportError: warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ') HAS_AUTO_GPTQ = False - -if HAS_AUTO_GPTQ: +if HAS_AUTO_GPTQ: from .gptq_triton import gptq_fused_linear_triton from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear from .gptq_op import CaiGPTQLinearOp From ac8ccd1a5f973caf84ebdbf684ea3c6ba84198d9 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 17 Aug 2023 16:22:52 +0800 Subject: [PATCH 11/15] change test flash attn version --- requirements/requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 5fee741ed322..e1430edc38fb 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -16,5 +16,5 @@ triton==2.0.0.dev20221202 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece ninja -flash_attn>=2.0 +flash_attn==2.0.5 auto-gptq From 0e0bcc3db3d71faae4c386d1d8aebac38671eb5f Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 17 Aug 2023 19:24:29 +0800 Subject: [PATCH 12/15] remove example --- .../quant/linear_act_fusion_bench.py | 380 ------------------ 1 file changed, 380 deletions(-) delete mode 100644 examples/inference/quant/linear_act_fusion_bench.py diff --git a/examples/inference/quant/linear_act_fusion_bench.py b/examples/inference/quant/linear_act_fusion_bench.py deleted file mode 100644 index cdbf9aeef68f..000000000000 --- a/examples/inference/quant/linear_act_fusion_bench.py +++ /dev/null @@ -1,380 +0,0 @@ -import torch -import torch.nn as nn -import pytest -import time -from argparse import ArgumentParser - -import transformers -from auto_gptq.quantization import GPTQ -from auto_gptq.modeling._utils import find_layers, pack_model -from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear - -from auto_gptq.quantization.quantizer import Quantizer -from colossalai.gptq import CaiGPTQLinearOp -import math -import numpy as np - -class MLinear(nn.Module): - def __init__(self, infeature, outfeature): - super(MLinear, self).__init__() - self.linear = torch.nn.Linear(infeature, outfeature, dtype=torch.float16) - def forward(self, x): - out = self.linear(x) - return out - -@torch.no_grad() -def model_quant(model, inps, dev, args): - print('Starting ...') - layers = [model] - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - cache = {'i': 0} - class Catcher(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - raise ValueError - layers[0] = Catcher(layers[0]) - # for batch in inps: - try: - model(inps.to(dev)) - except ValueError: - pass - layers[0] = layers[0].module - - # layers[0] = layers[0].cpu() - - # outs = torch.zeros_like(inps) - outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) - - print('Ready.') - - quantizers = {} - for i in range(len(layers)): - layer = layers[i].to(dev) - subset = find_layers(layer) - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name]) - # gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False, trits=args.trits) - - def add_batch(name): - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - return tmp - - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - - for j in range(args.nsamples): - outs[j] = layer(inps[j].unsqueeze(0))[0] - - for h in handles: - h.remove() - for name in subset: - print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') - scale,zero,g_idx = gptq[name].fasterquant(percdamp=args.percdamp, group_size=args.groupsize, actorder=args.act_order) - # quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) - quantizers['%s' % (name)] = (gptq[name].layer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) - - gptq[name].free() - for j in range(args.nsamples): - layer = layer.to(dev) - outs[j] = layer(inps[j].unsqueeze(0))[0] - - layers[i] = layer.cpu() - del layer - del gptq - torch.cuda.empty_cache() - - inps, outs = outs, inps - - return quantizers - - -def model_pack(model, quantizers, wbits, groupsize): - pack_model(model, quantizers, wbits, groupsize) - return model - - - -def cai_linear_pack(linear, scales, zeros, - out_qweight, out_qscales, out_qzeros, qg_idx, - infeatures, groupsize, bits): - g_idx = qg_idx.clone() if qg_idx is not None else torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32) - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - half_scales = scales.clone().half() - # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) - - out_qscales.data.copy_(scales) - - wn = 16 - pbits = 64 - ptype = torch.int64 - unsign_type = np.uint64 - sign_type = np.int64 - - # wn = 8 - # pbits = 32 - # ptype = torch.int32 - # unsign_type = np.uint32 - # sign_type = np.int32 - - intweight = [] - for idx in range(infeatures): - intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(unsign_type) - qweight = np.zeros((intweight.shape[0] // pbits * bits, intweight.shape[1]), dtype=unsign_type) - - i = 0 - row = 0 - # print("weight shape ", intweight.shape, qweight.shape, out_qweight.shape, bits) - # print("weight shape ", intweight[0].shape, qweight[0].shape, out_qweight[0].shape) - # print("weight value ", intweight[0], qweight[0]) - - while row < qweight.shape[0]: - if bits in [2, 4, 8]: - for j in range(i, i + (pbits // bits)): - qweight[row] |= intweight[j] << (bits * (j - i)) - i += pbits // bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - qweight = qweight.astype(sign_type) - qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous().to("cuda") - out_qweight.data.copy_(qweight1) - - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * bits), dtype=unsign_type) - zeros -= 1 - zeros = zeros.numpy().astype(unsign_type) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if bits in [2, 4, 8]: - for j in range(i, i + (pbits // bits)): - qzeros[:, col] |= zeros[:, j] << (bits * (j - i)) - i += pbits // bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - qzeros = qzeros.astype(sign_type) - qzeros = torch.from_numpy(qzeros) - qzeros = qzeros.to("cuda") - out_qzeros.data.copy_(qzeros) - - return out_qweight, out_qscales, out_qzeros - -def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - with torch.no_grad(): - for name in layers: - _, scale, zero, g_idx = quantizers[name] - qweight, qscales, qzeros = cai_linear_pack(layers[name], scale, zero, - qweight, qscales, qzeros, g_idx, - layers[name].weight.shape[-1], groupsize, wbits) - - # print("cai pack", layers) - return qweight, qscales, qzeros -if __name__ == "__main__": - - - parser = ArgumentParser() - parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.') - parser.add_argument('--wbits', type=int, default=4, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') - parser.add_argument('--trits', action='store_true', help='Whether to use trits for quantization.') - parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration data samples.') - parser.add_argument('--percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.') - parser.add_argument('--groupsize', type=int, default=128, help='Groupsize to use for quantization; default uses full row.') - parser.add_argument('--act-order', action='store_true', help='Whether to apply the activation order GPTQ heuristic') - args = parser.parse_args() - infeature = 8192 - outfeature = 8192 - - weight = torch.randn(outfeature, infeature).to(torch.float16).to(torch.cuda.current_device()) - bias = torch.zeros(outfeature).to(torch.float16).to(torch.cuda.current_device()) - wn = 16 - ptype = torch.int64 - - # wn = 8 - # ptype = torch.int32 - - qweight = torch.zeros(infeature//wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() - qscales = torch.zeros(infeature//args.groupsize, outfeature, dtype=torch.float16, device=torch.cuda.current_device()).contiguous() - qzeros = torch.zeros(infeature//args.groupsize, outfeature//wn, dtype=ptype, device=torch.cuda.current_device()).contiguous() - - - linear = MLinear(infeature, outfeature) - linear.to(torch.cuda.current_device()) - with torch.no_grad(): - linear.linear.weight.data.copy_(weight) - linear.linear.bias.data.copy_(bias) - inps = torch.randn(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) - quantizers = model_quant(linear, inps, torch.cuda.current_device(), args) - qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, args.wbits, args.groupsize) - - - batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) - - gptq_linear_time = 0 - torch_linear_time = 0 - warm_up_iter = 2 - benchmark_iter = 100 - - act_func = nn.ReLU() - linear.to("cuda") - for i in range(0, warm_up_iter): - with torch.no_grad(): - torch_out = act_func(inps) - # torch_out = inps - # print(f"torch out {torch_out}") - torch_out = linear(torch_out) - torch.cuda.synchronize() - - time_start = time.time() - for i in range(0, benchmark_iter): - with torch.no_grad(): - torch_out = act_func(inps) - # torch_out = inps - torch_out = linear(torch_out) - torch.cuda.synchronize() - - time_end = time.time() - torch_linear_time = time_end - time_start - - - time_start = time.time() - for i in range(0, benchmark_iter): - with torch.no_grad(): - torch_out = act_func(batch_inps) - # torch_out = inps - torch_out = linear(torch_out) - torch.cuda.synchronize() - - time_end = time.time() - torch_batch_linear_time = time_end - time_start - - linear.to("cpu") - - gptq_model = model_pack(linear, quantizers, args.wbits, args.groupsize) - gptq_model.to(torch.cuda.current_device()) - - # gptq_model = linear - - for i in range(0, warm_up_iter): - with torch.no_grad(): - gptq_out = act_func(inps) - # gptq_out = inps - gptq_out = gptq_model(gptq_out) - torch.cuda.synchronize() - - time_start = time.time() - for i in range(0, benchmark_iter): - with torch.no_grad(): - gptq_out = act_func(inps) - # gptq_out = inps - gptq_out = gptq_model(gptq_out) - torch.cuda.synchronize() - - time_end = time.time() - - gptq_linear_time = time_end - time_start - - for i in range(0, warm_up_iter): - with torch.no_grad(): - gptq_out = act_func(batch_inps) - # gptq_out = inps - gptq_out = gptq_model(gptq_out) - torch.cuda.synchronize() - - time_start = time.time() - for i in range(0, benchmark_iter): - with torch.no_grad(): - gptq_out = act_func(batch_inps) - # gptq_out = inps - gptq_out = gptq_model(gptq_out) - torch.cuda.synchronize() - - time_end = time.time() - - gptq_batch_linear_time = time_end - time_start - - # qweight = torch.cat((qweight, qweight, qweight), dim=0).contiguous() - # qscales = torch.cat((qscales, qscales, qscales), dim=0).contiguous() - # qzeros = torch.cat((qzeros, qzeros, qzeros), dim=0).contiguous() - # bias = torch.cat((bias, bias, bias), dim=0).contiguous() - qkv_fused = False - - cai_linear = CaiGPTQLinearOp(args.groupsize, args.wbits) - - print("cai linear") - for i in range(0, warm_up_iter): - with torch.no_grad(): - cai_out = cai_linear(inps, - qweight, - qscales, - qzeros, - act_type=0, - bias = bias, - qkv_fused = qkv_fused) - torch.cuda.synchronize() - - - print("warm up cai linear") - - # f = open('cai_time.csv', 'w') - # writer = csv.writer(f) - - time_start = time.time() - - for i in range(0, warm_up_iter): - with torch.no_grad(): - cai_out = cai_linear(batch_inps, - qweight, - qscales, - qzeros, - act_type=0, - bias = bias, - qkv_fused = qkv_fused) - torch.cuda.synchronize() - time_end = time.time() - - cai_linear_time = time_end - time_start - # print("block dim x:{}, block dim y:{}, time: {:.8f} ".format(i, j, cai_linear_time/benchmark_iter)) - # row=[i, j, cai_linear_time/benchmark_iter] - - - time_start = time.time() - for k in range(0, benchmark_iter): - with torch.no_grad(): - cai_out = cai_linear(batch_inps, - qweight, - qscales, - qzeros, - act_type=0, - bias = bias, - qkv_fused = qkv_fused) - torch.cuda.synchronize() - time_end = time.time() - - batch_cai_linear_time = time_end - time_start - - print("torch time: {:.8f}".format(torch_linear_time/benchmark_iter)) - print("gptq time:{:.8f}".format( gptq_linear_time/benchmark_iter)) - print("cai gptq time:{:.8f}".format( cai_linear_time/benchmark_iter)) - - print("batch torch time: {:.8f}".format(torch_batch_linear_time/benchmark_iter)) - print("batch gptq time:{:.8f}".format( gptq_batch_linear_time/benchmark_iter)) - print("batch cai gptq time:{:.8f}".format( batch_cai_linear_time/benchmark_iter)) From 4ff9c99fc856d4f6313fdff25c628f09bcddb731 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 18 Aug 2023 12:30:00 +0800 Subject: [PATCH 13/15] change requirements of flash_attn --- requirements/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 65eecce2c34f..f6be6a624c70 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,5 +10,4 @@ contexttimer ninja torch>=1.11 safetensors -flash_attn>=2.0 einops From 16a3f00c4967b6100f855538281ebd43c26d8d94 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 21 Aug 2023 16:46:39 +0800 Subject: [PATCH 14/15] modify tests --- tests/test_gptq/test_linear_act_fusion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py index 6dcae74c4304..89e1741b4a00 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -267,7 +267,6 @@ def test_gptq_linear(): batch_gptq_out = act_func(batch_gptq_out) gptq_out = act_func(gptq_out) - # cai_out = cai_out[1] # batch_cai_out = batch_cai_out[1] # a = torch.sum(qscales, 0) @@ -306,4 +305,4 @@ def test_gptq_linear(): if __name__ == "__main__": - test_gptq_linear() \ No newline at end of file + test_gptq_linear() From d24803119f624ba746baeb9f3d1409ee240d0456 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 21 Aug 2023 17:42:09 +0800 Subject: [PATCH 15/15] [skip ci] change requirements-test --- requirements/requirements-test.txt | 2 +- tests/test_gptq/test_linear_act_fusion.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e1430edc38fb..657bd3eb28d8 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -17,4 +17,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi SentencePiece ninja flash_attn==2.0.5 -auto-gptq +#auto-gptq now not support torch1.12 diff --git a/tests/test_gptq/test_linear_act_fusion.py b/tests/test_gptq/test_linear_act_fusion.py index 89e1741b4a00..4540d990dc3a 100644 --- a/tests/test_gptq/test_linear_act_fusion.py +++ b/tests/test_gptq/test_linear_act_fusion.py @@ -283,6 +283,7 @@ def test_gptq_linear(): assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-02) assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-02) + # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) # max_diff = torch.max(torch.abs(cai_out - gptq_out)) # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff))