50 changes: 50 additions & 0 deletions LICENSE
@@ -396,3 +396,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

---------------- LICENSE FOR AutoGPTQ ----------------

From AutoGPTQ:

MIT License

Copyright (c) 2023 潘其威(William)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

---------------- LICENSE FOR exllama ----------------

From exllama:

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
7 changes: 2 additions & 5 deletions colossalai/gptq/__init__.py
@@ -1,7 +1,4 @@
from .cai_gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear,
                           CaiQuantLinear, CaiGPTQLinearOp)


if HAS_AUTO_GPTQ:
    from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear, gptq_fused_linear_triton, make_cai_quant_linear
4 changes: 2 additions & 2 deletions colossalai/gptq/cai_gptq/__init__.py
@@ -9,6 +9,6 @@
HAS_AUTO_GPTQ = False

if HAS_AUTO_GPTQ:
    from .gptq_triton import gptq_fused_linear_triton
    from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear
    from .cai_quant_linear import CaiQuantLinear, make_cai_quant_linear
    from .gptq_op import CaiGPTQLinearOp
    from .gptq_triton import gptq_fused_linear_triton
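
For context: the HAS_AUTO_GPTQ flag tested above is presumably set by a guarded import in the collapsed lines at the top of this file, mirroring the HAS_GPTQ_CUDA guard added in cai_quant_linear.py below. A minimal hypothetical sketch of that optional-dependency pattern (illustrative only; the module name auto_gptq is an assumption, not the file's actual elided lines):

import warnings

# Hypothetical sketch: guard an optional dependency behind a feature flag.
# The real opening lines of this file are collapsed in the diff, so the
# module name `auto_gptq` here is an assumption.
try:
    import auto_gptq    # noqa: F401
    HAS_AUTO_GPTQ = True
except ImportError:
    warnings.warn('auto-gptq is not installed')
    HAS_AUTO_GPTQ = False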
161 changes: 125 additions & 36 deletions colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -1,12 +1,34 @@
# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ

import math
import warnings

import numpy as np
import torch
import torch.nn as nn
from .gptq_op import CaiGPTQLinearOp
import triton

from .gptq_op import CaiGPTQLinearOp

HAS_GPTQ_CUDA = False
try:
    from colossalai.kernel.op_builder.gptq import GPTQBuilder
    gptq_cuda = GPTQBuilder().load()
    HAS_GPTQ_CUDA = True
except ImportError:
    warnings.warn('CUDA gptq is not installed')
    HAS_GPTQ_CUDA = False


class CaiQuantLinear(nn.Module):
    max_dq_buffer_size = 1
    max_inner_outer_dim = 1
    max_input_len = 1
    prepared_buffers = False
    device_to_buffers = {
        "temp_state": None,
        "temp_dq": None,
    }

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        super().__init__()
@@ -18,9 +40,12 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64))
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
        self.register_buffer(
            'qzeros',
            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
        self.register_buffer('scales',
                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

        if bias:
@@ -30,10 +55,13 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):

        self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)

        self.q4 = None
        self.empty_tensor = torch.empty((1, 1), device="meta")

    def pack(self, linear, scales, zeros, g_idx=None):

        g_idx = g_idx.clone() if g_idx is not None else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
@@ -44,21 +72,24 @@ def pack(self, linear, scales, zeros, g_idx=None):
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        wn = 16
        pbits = 64
        ptype = torch.int64
        unsign_type = np.uint64
        sign_type = np.int64
        # wn = 16
        # pbits = 64
        # ptype = torch.int64
        # unsign_type = np.uint64
        # sign_type = np.int64

        # wn = 8
        # pbits = 32
        # ptype = torch.int32
        # unsign_type = np.uint32
        # sign_type = np.int32
        wn = 8
        pbits = 32
        ptype = torch.int32
        unsign_type = np.uint32
        sign_type = np.int32

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None])
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
                                                                                                                None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
@@ -72,50 +103,108 @@ def pack(self, linear, scales, zeros, g_idx=None):

        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qweight[row] |= intweight[j] << ( self.bits * (j - i))
                i += pbits // self.bits
                for j in range(i, i + (pbits // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += pbits // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qweight = qweight.astype(sign_type)
        qweight1 = torch.from_numpy(qweight)
        qweight1 = qweight1.contiguous()    #.to("cuda")
        qweight1 = qweight1.contiguous()    #.to("cuda")
        self.qweight.data.copy_(qweight1)

        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
        zeros -= 1
        zeros = zeros.numpy().astype(unsign_type)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << ( self.bits * (j - i))
                i += pbits // self.bits
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += pbits // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qzeros = qzeros.astype(sign_type)
        qzeros = torch.from_numpy(qzeros)
        qzeros = qzeros
        self.qzeros.data.copy_(qzeros)
        if torch.equal(self.g_idx, g_idx):

        if torch.equal(self.g_idx.to(g_idx.device), g_idx):
            self.g_idx = None
        else:
            self.g_idx = g_idx

        CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8)

        if self.g_idx is not None:
            CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures,
                                                     self.outfeatures)
            CaiQuantLinear.max_input_len = 4096

    def prepare_buffers(self):
        assert self.qweight.device.type == "cuda"
        device = self.qweight.device

        # The temp_state buffer is required to reorder X in the act-order case.
        # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
        CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros(
            (CaiQuantLinear.max_input_len, CaiQuantLinear.max_inner_outer_dim), dtype=torch.float16, device=device)
        CaiQuantLinear.device_to_buffers['temp_dq'] = torch.zeros((1, CaiQuantLinear.max_dq_buffer_size),
                                                                  dtype=torch.float16,
                                                                  device=device)

        gptq_cuda.prepare_buffers(torch.device(device), CaiQuantLinear.device_to_buffers['temp_state'],
                                  CaiQuantLinear.device_to_buffers['temp_dq'])

        # Using the default from exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

        torch.cuda.empty_cache()

    def init_q4(self):
        assert self.qweight.device.type == "cuda"
        self.q4_width = self.qweight.shape[1]
        if self.g_idx is not None:
            g_idx = self.g_idx.to("cpu")
        else:
            g_idx = self.empty_tensor

        self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
        torch.cuda.synchronize()

    def forward(self, x):
        outshape = x.shape[:-1] + (self.outfeatures,)

        if HAS_GPTQ_CUDA:
            if CaiQuantLinear.prepared_buffers == False:
                self.prepare_buffers()
                CaiQuantLinear.prepared_buffers = True

            if self.q4 is None:
                self.init_q4()

            x = x.view(-1, x.shape[-1])
            output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
            gptq_cuda.q4_matmul(x, self.q4, output)
            if self.bias is not None:
                output.add_(self.bias)
        else:
            output = self.gptq_linear(
                x,
                self.qweight,
                self.scales,
                self.qzeros,
                g_idx=self.g_idx,
                bias=self.bias,
            )
        return output.view(outshape)

        cai_out = self.gptq_linear(x,
                                   self.qweight,
                                   self.scales,
                                   self.qzeros,
                                   g_idx = self.g_idx,
                                   bias = self.bias,)
        return cai_out

def make_cai_quant_linear(module, names, bits, groupsize, name=''):
    if isinstance(module, CaiQuantLinear):
@@ -125,7 +214,7 @@ def make_cai_quant_linear(module, names, bits, groupsize, name=''):
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr, CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
            setattr(module, attr,
                    CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
    for name1, child in module.named_children():
        make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
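
The switch in this commit from int64 words (pbits = 64) to int32 words (pbits = 32) changes only the word width of the packed layout, not the layout itself: with bits=4, pack() ORs eight 4-bit integers little-end-first into each 32-bit word of qweight, and packs qzeros the same way along columns. A minimal standalone sketch of that row-wise layout, with hypothetical helpers pack_rows/unpack_rows that do not appear in the PR:

import numpy as np

def pack_rows(intweight: np.ndarray, bits: int = 4, pbits: int = 32) -> np.ndarray:
    """Pack integers in [0, 2**bits) row-wise into pbits-wide words, low bits first."""
    assert intweight.dtype == np.uint32 and pbits == 32
    vals_per_word = pbits // bits
    packed = np.zeros((intweight.shape[0] // vals_per_word, intweight.shape[1]), dtype=np.uint32)
    for row in range(packed.shape[0]):
        for k in range(vals_per_word):
            # Same operation as the qweight loop in pack() above.
            packed[row] |= intweight[row * vals_per_word + k] << np.uint32(bits * k)
    return packed.astype(np.int32)

def unpack_rows(packed: np.ndarray, bits: int = 4, pbits: int = 32) -> np.ndarray:
    """Inverse of pack_rows, to check that the layout round-trips."""
    vals_per_word = pbits // bits
    mask = np.uint32((1 << bits) - 1)
    words = packed.astype(np.uint32)
    out = np.zeros((packed.shape[0] * vals_per_word, packed.shape[1]), dtype=np.uint32)
    for row in range(packed.shape[0]):
        for k in range(vals_per_word):
            out[row * vals_per_word + k] = (words[row] >> np.uint32(bits * k)) & mask
    return out

# Round-trip check: 16 rows of 4-bit values pack into 2 int32 rows.
w = np.random.randint(0, 16, size=(16, 4)).astype(np.uint32)
assert (unpack_rows(pack_rows(w)) == w).all()

At runtime, forward() dispatches to the exllama q4_matmul CUDA kernel whenever the GPTQBuilder extension loaded, lazily allocating the class-level temp_state/temp_dq buffers on the first call, and falls back to the fused Triton kernel otherwise.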

28 changes: 20 additions & 8 deletions colossalai/gptq/cai_gptq/gptq_op.py
@@ -1,6 +1,7 @@
from .gptq_triton import gptq_fused_linear_triton
import torch

from .gptq_triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):

@@ -17,10 +18,10 @@ def forward(self,
                weight_scales: torch.Tensor,
                weight_zeros: torch.Tensor,
                g_idx: torch.Tensor = None,
                act_type = 0,
                act_type=0,
                bias: torch.Tensor = None,
                residual: torch.Tensor=None,
                qkv_fused = False):
                residual: torch.Tensor = None,
                qkv_fused=False):

        add_bias = True
        if bias is None:
@@ -33,12 +34,23 @@ def forward(self,
        add_residual = False
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual,
                                       self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual,
                                       act_type=act_type, g_idx=g_idx)
        out = gptq_fused_linear_triton(x,
                                       weight,
                                       weight_scales,
                                       weight_zeros,
                                       bias,
                                       residual,
                                       self.bits,
                                       self.maxq,
                                       self.group_size,
                                       qkv_fused,
                                       add_bias,
                                       add_residual,
                                       act_type=act_type,
                                       g_idx=g_idx)
        if qkv_fused:
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
        return out
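
Taken together, the pieces compose as follows: make_cai_quant_linear swaps named nn.Linear modules for CaiQuantLinear shells, pack() fills their quantized buffers, and forward() serves inference through either the CUDA or the Triton path. A hedged usage sketch, assuming AutoGPTQ is installed so HAS_AUTO_GPTQ is True; TinyBlock and its layer names are invented for illustration:

import torch.nn as nn

from colossalai.gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from colossalai.gptq import make_cai_quant_linear

    class TinyBlock(nn.Module):
        # Invented toy module; real use would target a transformer's projections.
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(64, 64, bias=False)
            self.o_proj = nn.Linear(64, 64, bias=False)

    model = TinyBlock()
    # Replace the listed Linear layers in place with 4-bit CaiQuantLinear shells.
    # Their qweight/qzeros/scales buffers stay zero until pack() (or a quantized
    # state dict) fills them.
    make_cai_quant_linear(model, names=['q_proj', 'o_proj'], bits=4, groupsize=64)
    print(type(model.q_proj).__name__)    # CaiQuantLinear

Because the replacement happens in place via setattr, the module tree and attribute names are preserved, which keeps checkpoint key names stable when loading quantized weights.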