From 35c551f50ee77441213f1375f488577f88b217aa Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 4 Oct 2023 16:38:41 +0800 Subject: [PATCH 1/5] add llama mlp for smoothquant --- .../quant/smoothquant/models/__init__.py | 2 +- .../quant/smoothquant/models/linear.py | 49 +++++++++++ .../models/{smoothquant_layer.py => llama.py} | 51 ++++++++++- tests/test_smoothquant/test_llama_mlp.py | 85 +++++++++++++++++++ 4 files changed, 184 insertions(+), 3 deletions(-) create mode 100644 colossalai/inference/quant/smoothquant/models/linear.py rename colossalai/inference/quant/smoothquant/models/{smoothquant_layer.py => llama.py} (81%) create mode 100644 tests/test_smoothquant/test_llama_mlp.py diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py index 9382a10cffa6..77541d8610c5 100644 --- a/colossalai/inference/quant/smoothquant/models/__init__.py +++ b/colossalai/inference/quant/smoothquant/models/__init__.py @@ -9,4 +9,4 @@ ) if HAS_TORCH_INT: - from .smoothquant_layer import LLamaSmoothquantAttention + from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py new file mode 100644 index 000000000000..1c01c6222e7a --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/linear.py @@ -0,0 +1,49 @@ +import torch +from torch_int.functional.quantization import quantize_per_tensor_absmax + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except ImportError: + HAS_SMOOTHQUANT_CUDA = False + raise ImportError("CUDA smoothquant linear is not installed") + + +class W8A8BFP32O32LinearSiLU(torch.nn.Module): + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False), + ) + self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False)) + self.register_buffer("a", torch.tensor(alpha)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + int8_module.bias.data.copy_(module.bias.to(torch.float)) + int8_module.a = alpha + return int8_module diff --git a/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py b/colossalai/inference/quant/smoothquant/models/llama.py similarity index 81% rename from colossalai/inference/quant/smoothquant/models/smoothquant_layer.py rename to colossalai/inference/quant/smoothquant/models/llama.py index b021e604ab09..ec21289be289 100644 --- a/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -6,10 +6,12 @@ from torch import nn from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear -from transformers.models.llama.modeling_llama import LlamaAttention +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP from colossalai.kernel.triton import int8_rotary_embedding_fwd +from .linear import W8A8BFP32O32LinearSiLU + class LLamaSmoothquantAttention(nn.Module): def __init__( @@ -100,7 +102,11 @@ def forward( self.rotary_output_scale, ) int8_rotary_embedding_fwd( - key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale + key_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.k_output_scale, + self.rotary_output_scale, ) if past_key_value is not None: @@ -183,3 +189,44 @@ def forward( attn_output = self.out_proj(attn_output) return attn_output, attn_probs_reshaped, past_key_value + + +class LlamaSmoothquantMLP(nn.Module): + def __init__(self, intermediate_size, hidden_size): + super().__init__() + self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) + self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) + self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) + self.down_proj_input_scale = 1.0 + self.inter_out_scale = 1.0 + + def pack( + self, + mlp_module: LlamaMLP, + gate_proj_input_scale: float, + up_proj_input_scale: float, + down_proj_input_scale: float, + ): + int8_module = LlamaSmoothquantMLP( + mlp_module.intermediate_size, + mlp_module.hidden_size, + ) + + int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) + int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) + int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) + self.down_proj_input_scale = down_proj_input_scale + return int8_module + + def forward( + self, + hidden_states: torch.Tensor, + ): + x_shape = hidden_states.shape + gate_out = self.gate_proj(hidden_states) + up_out = self.up_proj(hidden_states) + inter_out = gate_out * up_out + inter_out = inter_out.div_(self.inter_out_scale).round().clamp(-128, 127).to(torch.int8) + down_out = self.down_proj(inter_out) + down_out = down_out.view(*x_shape[:-1], -1) + return down_out diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py new file mode 100644 index 000000000000..0cd1ce8a0b9c --- /dev/null +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -0,0 +1,85 @@ +import warnings + +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except ImportError: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + +from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP + +try: + from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_llama_mlp(gate_proj, up_proj, down_proj, x): + gate_out = torch.mm(x, gate_proj) + silu = torch.nn.SiLU() + gate_out = silu(gate_out) + up_out = torch.mm(x, up_proj) + + o_out = gate_out * up_out + + max_up = torch.max(torch.abs(o_out)) + min_up = torch.min(torch.abs(o_out)) + + torch_out = torch.mm(o_out, down_proj) + + return (torch_out, max_up, min_up) + + +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT, + reason="smoothquant linear not installed properly or not install torch_int", +) +def test_linear(): + hidden_size = 256 + intermediate_size = 512 + + smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size) + + smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda") + + smooth_mlp.up_proj.weight = torch.randint( + -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda" + ) + smooth_mlp.down_proj.weight = torch.randint( + -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda" + ) + + x = torch.ones((1, 256), dtype=torch.int8, device="cuda") + + torch_out, max_inter, min_inter = torch_llama_mlp( + smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size, + smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127, + smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127, + x.to(torch.float), + ) + + smooth_mlp.inter_out_scale = max_inter.item() / 127 + smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size) + smooth_mlp.up_proj.a = torch.tensor(1 / 127) + smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127)) + + smooth_out = smooth_mlp(x) + + assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01) + + +if __name__ == "__main__": + test_linear() From e451bb41badfb2dda0a5e3427e5a46d7829dd374 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 4 Oct 2023 16:49:18 +0800 Subject: [PATCH 2/5] fix down out scale --- colossalai/inference/quant/smoothquant/models/llama.py | 3 +-- tests/test_smoothquant/test_llama_mlp.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index ec21289be289..34449dbfe03d 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -198,7 +198,6 @@ def __init__(self, intermediate_size, hidden_size): self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) self.down_proj_input_scale = 1.0 - self.inter_out_scale = 1.0 def pack( self, @@ -226,7 +225,7 @@ def forward( gate_out = self.gate_proj(hidden_states) up_out = self.up_proj(hidden_states) inter_out = gate_out * up_out - inter_out = inter_out.div_(self.inter_out_scale).round().clamp(-128, 127).to(torch.int8) + inter_out = inter_out.div_(self.down_proj_input_scale).round().clamp(-128, 127).to(torch.int8) down_out = self.down_proj(inter_out) down_out = down_out.view(*x_shape[:-1], -1) return down_out diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py index 0cd1ce8a0b9c..5786bf1aa188 100644 --- a/tests/test_smoothquant/test_llama_mlp.py +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -71,7 +71,7 @@ def test_linear(): x.to(torch.float), ) - smooth_mlp.inter_out_scale = max_inter.item() / 127 + smooth_mlp.down_proj_input_scale = max_inter.item() / 127 smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size) smooth_mlp.up_proj.a = torch.tensor(1 / 127) smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127)) From e7696db41ed2d92dfacb215103c91601a87e84f8 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 4 Oct 2023 17:07:26 +0800 Subject: [PATCH 3/5] remove duplicate lines --- tests/test_smoothquant/test_llama_mlp.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py index 5786bf1aa188..c29dbe6d5b56 100644 --- a/tests/test_smoothquant/test_llama_mlp.py +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -13,8 +13,6 @@ warnings.warn("CUDA smoothquant linear is not installed") HAS_SMOOTHQUANT_CUDA = False -from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP - try: from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP From e52760d7f3ed2d806f2df991b2733425c70e3b17 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 5 Oct 2023 09:15:39 +0800 Subject: [PATCH 4/5] add llama mlp check --- tests/test_smoothquant/test_llama_mlp.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py index c29dbe6d5b56..d500820847c7 100644 --- a/tests/test_smoothquant/test_llama_mlp.py +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -4,20 +4,24 @@ import torch from packaging import version +from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + try: from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder smoothquant_cuda = SmoothquantBuilder().load() HAS_SMOOTHQUANT_CUDA = True -except ImportError: +except: warnings.warn("CUDA smoothquant linear is not installed") HAS_SMOOTHQUANT_CUDA = False +from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP + try: from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP HAS_TORCH_INT = True -except ImportError: +except: HAS_TORCH_INT = False warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") @@ -45,7 +49,7 @@ def torch_llama_mlp(gate_proj, up_proj, down_proj, x): not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT, reason="smoothquant linear not installed properly or not install torch_int", ) -def test_linear(): +def test_llama_mlp(): hidden_size = 256 intermediate_size = 512 @@ -80,4 +84,4 @@ def test_linear(): if __name__ == "__main__": - test_linear() + test_llama_mlp() From 848f71fa254fc3861a2bd1fb7c0e5d357b610299 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 5 Oct 2023 09:17:34 +0800 Subject: [PATCH 5/5] delete useless code --- tests/test_smoothquant/test_llama_mlp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py index d500820847c7..ec0aaaba0198 100644 --- a/tests/test_smoothquant/test_llama_mlp.py +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -4,8 +4,6 @@ import torch from packaging import version -from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder - try: from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder @@ -15,7 +13,6 @@ warnings.warn("CUDA smoothquant linear is not installed") HAS_SMOOTHQUANT_CUDA = False -from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP try: from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP