diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py
index 9382a10cffa6..77541d8610c5 100644
--- a/colossalai/inference/quant/smoothquant/models/__init__.py
+++ b/colossalai/inference/quant/smoothquant/models/__init__.py
@@ -9,4 +9,4 @@
 )

 if HAS_TORCH_INT:
-    from .smoothquant_layer import LLamaSmoothquantAttention
+    from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py
new file mode 100644
index 000000000000..1c01c6222e7a
--- /dev/null
+++ b/colossalai/inference/quant/smoothquant/models/linear.py
@@ -0,0 +1,49 @@
+import torch
+from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+try:
+    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
+
+    smoothquant_cuda = SmoothquantBuilder().load()
+    HAS_SMOOTHQUANT_CUDA = True
+except ImportError:
+    HAS_SMOOTHQUANT_CUDA = False
+    raise ImportError("CUDA smoothquant linear is not installed")
+
+
+class W8A8BFP32O32LinearSiLU(torch.nn.Module):
+    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.register_buffer(
+            "weight",
+            torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
+        )
+        self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False))
+        self.register_buffer("a", torch.tensor(alpha))
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.weight = self.weight.to(*args, **kwargs)
+        self.bias = self.bias.to(*args, **kwargs)
+        return self
+
+    @torch.no_grad()
+    def forward(self, x):
+        x_shape = x.shape
+        x = x.view(-1, x_shape[-1])
+        y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
+        y = y.view(*x_shape[:-1], -1)
+        return y
+
+    @staticmethod
+    def from_float(module: torch.nn.Linear, input_scale):
+        int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
+        int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
+        alpha = input_scale * weight_scale
+        int8_module.weight = int8_weight
+        int8_module.bias.data.copy_(module.bias.to(torch.float))
+        int8_module.a = alpha
+        return int8_module
diff --git a/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py b/colossalai/inference/quant/smoothquant/models/llama.py
similarity index 81%
rename from colossalai/inference/quant/smoothquant/models/smoothquant_layer.py
rename to colossalai/inference/quant/smoothquant/models/llama.py
index b021e604ab09..34449dbfe03d 100644
--- a/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py
+++ b/colossalai/inference/quant/smoothquant/models/llama.py
@@ -6,10 +6,12 @@
 from torch import nn
 from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
 from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP

 from colossalai.kernel.triton import int8_rotary_embedding_fwd

+from .linear import W8A8BFP32O32LinearSiLU
+

 class LLamaSmoothquantAttention(nn.Module):
     def __init__(
@@ -100,7 +102,11 @@ def forward(
             self.rotary_output_scale,
         )
         int8_rotary_embedding_fwd(
-            key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale
+            key_states.view(-1, self.num_heads, self.head_dim),
+            cos,
+            sin,
+            self.k_output_scale,
+            self.rotary_output_scale,
         )

         if past_key_value is not None:
@@ -183,3 +189,43 @@ def forward(
         attn_output = self.out_proj(attn_output)

         return attn_output, attn_probs_reshaped, past_key_value
+
+
+class LlamaSmoothquantMLP(nn.Module):
+    def __init__(self, intermediate_size, hidden_size):
+        super().__init__()
+        self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
+        self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
+        self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
+        self.down_proj_input_scale = 1.0
+
+    def pack(
+        self,
+        mlp_module: LlamaMLP,
+        gate_proj_input_scale: float,
+        up_proj_input_scale: float,
+        down_proj_input_scale: float,
+    ):
+        int8_module = LlamaSmoothquantMLP(
+            mlp_module.intermediate_size,
+            mlp_module.hidden_size,
+        )
+
+        int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
+        int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
+        int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
+        int8_module.down_proj_input_scale = down_proj_input_scale
+        return int8_module
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ):
+        x_shape = hidden_states.shape
+        gate_out = self.gate_proj(hidden_states)
+        up_out = self.up_proj(hidden_states)
+        inter_out = gate_out * up_out
+        inter_out = inter_out.div_(self.down_proj_input_scale).round().clamp(-128, 127).to(torch.int8)
+        down_out = self.down_proj(inter_out)
+        down_out = down_out.view(*x_shape[:-1], -1)
+        return down_out
diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py
new file mode 100644
index 000000000000..ec0aaaba0198
--- /dev/null
+++ b/tests/test_smoothquant/test_llama_mlp.py
@@ -0,0 +1,84 @@
+import warnings
+
+import pytest
+import torch
+from packaging import version
+
+try:
+    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
+
+    smoothquant_cuda = SmoothquantBuilder().load()
+    HAS_SMOOTHQUANT_CUDA = True
+except:
+    warnings.warn("CUDA smoothquant linear is not installed")
+    HAS_SMOOTHQUANT_CUDA = False
+
+
+try:
+    from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP
+
+    HAS_TORCH_INT = True
+except:
+    HAS_TORCH_INT = False
+    warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
+
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+
+
+def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
+    gate_out = torch.mm(x, gate_proj)
+    silu = torch.nn.SiLU()
+    gate_out = silu(gate_out)
+    up_out = torch.mm(x, up_proj)
+
+    o_out = gate_out * up_out
+
+    max_up = torch.max(torch.abs(o_out))
+    min_up = torch.min(torch.abs(o_out))
+
+    torch_out = torch.mm(o_out, down_proj)
+
+    return (torch_out, max_up, min_up)
+
+
+@pytest.mark.skipif(
+    not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
+    reason="smoothquant linear not installed properly or torch_int is not installed",
+)
+def test_llama_mlp():
+    hidden_size = 256
+    intermediate_size = 512
+
+    smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)
+
+    smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")
+
+    smooth_mlp.up_proj.weight = torch.randint(
+        -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
+    )
+    smooth_mlp.down_proj.weight = torch.randint(
+        -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
+    )
+
+    x = torch.ones((1, 256), dtype=torch.int8, device="cuda")
+
+    torch_out, max_inter, min_inter = torch_llama_mlp(
+        smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
+        smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
+        smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
+        x.to(torch.float),
+    )
+
+    smooth_mlp.down_proj_input_scale = max_inter.item() / 127
+    smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
+    smooth_mlp.up_proj.a = torch.tensor(1 / 127)
+    smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))
+
+    smooth_out = smooth_mlp(x)
+
+    assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)
+
+
+if __name__ == "__main__":
+    test_llama_mlp()
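A minimal usage sketch for the new LlamaSmoothquantMLP, outside the patch itself. It assumes a transformers version whose LlamaMLP is built from a LlamaConfig, and the three activation scales are placeholder values standing in for per-tensor scales collected during calibration.

# Sketch: quantize a float LlamaMLP with pack(), then run it on int8 activations.
# The scale values below are placeholders, not calibrated results.
import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP

from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP

config = LlamaConfig(hidden_size=256, intermediate_size=512)
float_mlp = LlamaMLP(config)

# Placeholder calibration outputs: per-tensor scale of each projection's int8 input.
gate_proj_input_scale = 0.02
up_proj_input_scale = 0.02
down_proj_input_scale = 0.05

# pack() quantizes the three projections and records the down_proj input scale.
int8_mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size).pack(
    float_mlp, gate_proj_input_scale, up_proj_input_scale, down_proj_input_scale
).cuda()

# The quantized MLP consumes int8 activations (already divided by the input scale)
# and returns fp32 hidden states with the same leading shape.
x = torch.randint(-128, 127, (1, 16, config.hidden_size), dtype=torch.int8, device="cuda")
out = int8_mlp(x)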