From 80ad1df2c0d5c51fa6489cf60ff36fa45ccf23a7 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 3 Oct 2023 09:46:27 +0800 Subject: [PATCH 1/5] add smoothquant llama attention --- .../quant/smoothquant/models/__init__.py | 0 .../smoothquant/models/smoothquant_layer.py | 183 ++++++++++++++++++ .../smoothquant/models/__init__.py | 0 .../smoothquant/models/smoothquant_layer.py | 183 ++++++++++++++++++ .../triton/int8_rotary_embedding_kernel.py | 18 +- .../test_smoothquant/test_llama_attention.py | 103 ++++++++++ 6 files changed, 477 insertions(+), 10 deletions(-) create mode 100644 colossalai/inference/quant/smoothquant/models/__init__.py create mode 100644 colossalai/inference/quant/smoothquant/models/smoothquant_layer.py create mode 100644 colossalai/inference/quant/smoothquant/smoothquant/models/__init__.py create mode 100644 colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py create mode 100644 tests/test_smoothquant/test_llama_attention.py diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py b/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py new file mode 100644 index 000000000000..ac1e430d41cd --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py @@ -0,0 +1,183 @@ +from typing import Optional, Tuple + +import torch +from torch import nn +from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T +from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear +from transformers.models.llama.modeling_llama import LlamaAttention + +from colossalai.kernel.triton import int8_rotary_embedding_fwd + + +class LLamaSmoothquantAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." 
+ ) + + self.attention_weight_scale = 1.0 + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + + self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.out_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + + self.q_output_scale = torch.tensor([1.0]) + self.k_output_scale = torch.tensor([1.0]) + self.rotary_output_scale = torch.tensor([1.0]) + + def pack( + self, + module: LlamaAttention, + input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float, + rotary_output_scale: float, + ): + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.head_dim) + int8_module.q_output_scale = q_output_scale + int8_module.k_output_scale = k_output_scale + int8_module.rotary_output_scale = rotary_output_scale + q_output_scale = q_output_scale * module.scaling + module.q_proj.weight *= module.scaling + module.q_proj.bias *= module.scaling + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, input_scale, q_output_scale) + + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, input_scale, v_output_scale) + int8_module.out_proj = W8A8BFP32OFP32Linear.from_float(module.out_proj, out_input_scale) + int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale) + + # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 + int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) + return int8_module + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor], + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, seq_len, _ = hidden_states.size() + # get query proj + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + cos = rotary_emb[0] + sin = rotary_emb[1] + int8_rotary_embedding_fwd( + query_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.q_output_scale, + self.rotary_output_scale, + ) + int8_rotary_embedding_fwd( + key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale + ) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(key_states, -1, bsz) + value_states = self._shape(value_states, -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(key_states, -1, bsz) + value_states = self._shape(value_states, -1, bsz) + + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + + query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) 
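# NOTE (reference sketch, not part of this patch): the two
# int8_rotary_embedding_fwd calls above rewrite the query/key states in place by
# dequantizing with the projection output scale, rotating with cos/sin, and
# requantizing with rotary_output_scale. A minimal float-space equivalent,
# assuming the usual first-half/second-half (rotate_half) split of head_dim;
# ref_int8_rotary and its argument names are illustrative only:
import torch

def ref_int8_rotary(x_q8, cos, sin, input_scale, output_scale):
    # x_q8: (tokens, num_heads, head_dim) int8; cos/sin: (tokens, head_dim // 2)
    x = x_q8.to(torch.float32) * input_scale          # dequantize
    d = x.shape[-1] // 2
    x0, x1 = x[..., :d], x[..., d:]                   # split rotary halves
    cos_, sin_ = cos.unsqueeze(1), sin.unsqueeze(1)   # broadcast over heads
    out0 = (x0 * cos_ - x1 * sin_) / output_scale     # rotate, then rescale
    out1 = (x0 * sin_ + x1 * cos_) / output_scale
    out = torch.cat([out0, out1], dim=-1)
    return out.round_().clamp_(-128, 127).to(torch.int8)   # requantize to int8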
+ attn_weights = self.qk_bmm(query_states, key_states) + + if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, seq_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len) + + attn_probs = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs.view(bsz, self.num_heads, seq_len, src_len) + attn_probs = attn_probs.view(bsz * self.num_heads, seq_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len) + attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len) + else: + attn_probs_reshaped = None + + # (A_row V_row)_row = (A_row V_col ^T)_row + attn_probs.mul_(127).round_() + attn_probs = attn_probs.to(torch.int8) + + value_states = value_states.transpose(1, 2).contiguous() + attn_output = self.pv_bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
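# NOTE (reference sketch, not part of this patch): the int8 pipeline above --
# qk_bmm producing fp32 scores, softmax, the *127 requantization of the
# probabilities, then the int8 pv_bmm -- corresponds to the scales wired up in
# pack(): qk_bmm alpha = s_q * s_k (with the attention scaling folded into s_q),
# and pv_bmm alpha = s_prob * s_v / s_out with s_prob = 1/127. A float-space
# reading; ref_int8_attention and the scale names are illustrative only:
import torch

def ref_int8_attention(q_q8, k_q8, v_q8, s_q, s_k, s_v, s_out, mask=None):
    # q_q8 / k_q8 / v_q8: (batch * heads, seq, head_dim) int8 tensors
    q = q_q8.float() * s_q   # dequantize; attention scaling assumed folded into s_q
    k = k_q8.float() * s_k
    v = v_q8.float() * s_v
    scores = q @ k.transpose(-1, -2)          # what the fp32 qk_bmm represents
    if mask is not None:
        scores = scores + mask
    probs = torch.softmax(scores, dim=-1)
    probs_q8 = probs.mul(127).round().to(torch.int8)    # s_prob = 1/127
    out = probs_q8.float().mul(1.0 / 127) @ v           # float value of the int8 pv matmul
    return (out / s_out).round().clamp(-128, 127).to(torch.int8)  # requantize to s_out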
+ attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_probs_reshaped, past_key_value diff --git a/colossalai/inference/quant/smoothquant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/smoothquant/models/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py b/colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py new file mode 100644 index 000000000000..122b7320e07d --- /dev/null +++ b/colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py @@ -0,0 +1,183 @@ +from typing import Optional, Tuple + +import torch +from torch import nn +from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T +from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear +from transformers.models.llama.modeling_llama import LlamaAttention + +from colossalai.kernel.triton import int8_rotary_embedding_fwd + + +class LLamaSmoothquantAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) + + self.attention_weight_scale = 1.0 + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + + self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.out_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + + self.q_output_scale = torch.tensor([1.0]) + self.k_output_scale = torch.tensor([1.0]) + self.rotary_output_scale = torch.tensor([1.0]) + + def pack( + self, + module: LlamaAttention, + input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float, + rotary_output_scale: float, + ): + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.head_dim) + int8_module.q_output_scale = q_output_scale + int8_module.k_output_scale = k_output_scale + int8_module.rotary_output_scale = rotary_output_scale + q_output_scale = q_output_scale * module.scaling + module.q_proj.weight *= module.scaling + module.q_proj.bias *= module.scaling + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, input_scale, q_output_scale) + + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, input_scale, v_output_scale) + int8_module.out_proj = W8A8BFP32OFP32Linear.from_float(module.out_proj, out_input_scale) + int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale) + + # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 + int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) + return int8_module + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor], + key_value_states: Optional[torch.Tensor] = None, + 
past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, seq_len, _ = hidden_states.size() + # get query proj + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + # TODO: rotary embedding + cos = rotary_emb[0] + sin = rotary_emb[1] + int8_rotary_embedding_fwd( + query_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.q_output_scale, + self.rotary_output_scale, + ) + int8_rotary_embedding_fwd( + key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale + ) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(key_states, -1, bsz) + value_states = self._shape(value_states, -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(key_states, -1, bsz) + value_states = self._shape(value_states, -1, bsz) + + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + + query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = self.qk_bmm(query_states, key_states) + + if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, seq_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len) + + attn_probs = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs.view(bsz, self.num_heads, seq_len, src_len) + attn_probs = attn_probs.view(bsz * self.num_heads, seq_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len) + attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len) + else: + attn_probs_reshaped = None + + # (A_row V_row)_row = (A_row V_col ^T)_row + attn_probs.mul_(127).round_() + attn_probs = attn_probs.to(torch.int8) + + value_states = value_states.transpose(1, 2).contiguous() + attn_output = self.pv_bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_probs_reshaped, past_key_value diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py index 1e2c5c427954..dfad8a973ed6 100644 --- a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py @@ -57,17 +57,15 @@ def _rotary_kernel( cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - in_scale = tl.load(input_scale) - o_scale = tl.load(output_scale) - q0 = q0.to(tl.float32) * in_scale - q1 = q1.to(tl.float32) * in_scale + q0 = q0.to(tl.float32) * input_scale + q1 = q1.to(tl.float32) * input_scale - out0 = (q0 * cos - q1 * sin) / o_scale - out1 = (q0 * sin + q1 * cos) / o_scale + out0 = (q0 * cos - q1 * sin) / output_scale + out1 = (q0 * sin + q1 * cos) / output_scale - # out0 = out0.to(tl.int8) - # out1 = out1.to(tl.int8) + out0 = out0.to(tl.int8) + out1 = out1.to(tl.int8) tl.store( q + off_q0, @@ -99,8 +97,8 @@ def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): _rotary_kernel[grid]( q, - input_scale, - output_scale, + input_scale.item(), + output_scale.item(), cos, sin, q.stride(0), diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py new file mode 100644 index 000000000000..c4f111aaff2a --- /dev/null +++ b/tests/test_smoothquant/test_llama_attention.py @@ -0,0 +1,103 @@ +import pytest +import torch +from packaging import version + +from colossalai.inference.quant.smoothquant.models.smoothquant_layer import LLamaSmoothquantAttention +from colossalai.kernel.triton import int8_rotary_embedding_fwd + +try: + from colossalai.inference.quant.smoothquant.models.smoothquant_layer import LLamaSmoothquantAttention + from colossalai.kernel.triton import int8_rotary_embedding_fwd + + # /home/lcxk/data3/test_tp_infer/ColossalAI/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +TRITON_CUDA_SUPPORT = 
version.parse(torch.version.cuda) > version.parse("11.4") + +import math + +import torch +from torch.nn import functional as F + + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + """ + adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.0] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + sm_scale = 1 / math.sqrt(head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale + scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float) + + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + return output + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_llama_context_attention(): + head_num = 8 + seq_len = 32 + head_dim = 64 + dtype = torch.float + hidden_size = head_num * head_dim + + smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num) + + smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) + smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) + smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) + smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size).to(torch.int8) + + smooth_attn = smooth_attn.to("cuda") + + input = torch.randint(-127, 127, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") + + q = smooth_attn.q_proj(input) + k = smooth_attn.k_proj(input) + v = smooth_attn.v_proj(input) + + cos_shape = (seq_len, head_dim // 2) + cos = torch.ones(cos_shape, dtype=dtype, device="cuda") + sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") + + in_scale = torch.tensor([1.0], device="cuda") + out_scale = torch.tensor([1.0], device="cuda") + + int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) + int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale, out_scale) + + q = q.to(torch.float) + k = k.to(torch.float) + v = v.to(torch.float) + torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) + torch_out = (torch_out).to(torch.int8).view(-1, seq_len, head_num * head_dim) + torch_out = smooth_attn.out_proj(torch_out) + smooth_out, _, _ = smooth_attn(input, (cos, sin)) + smooth_out = smooth_out.to(torch.float) + torch_out = torch_out.to(torch.float) + + assert torch.allclose( + smooth_out.cpu(), torch_out.cpu(), rtol=1e-2, atol=1e-2 + ), "outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_llama_context_attention() From 6f4060a7c7115843ba98c8699ea81fa5de8c876f Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 3 Oct 2023 09:54:25 +0800 Subject: [PATCH 2/5] remove uselss code --- .../smoothquant/models/smoothquant_layer.py | 2 + .../smoothquant/models/__init__.py | 0 .../smoothquant/models/smoothquant_layer.py | 183 ------------------ 3 files changed, 2 insertions(+), 183 deletions(-) delete mode 100644 colossalai/inference/quant/smoothquant/smoothquant/models/__init__.py delete mode 
100644 colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py diff --git a/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py b/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py index ac1e430d41cd..b021e604ab09 100644 --- a/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py +++ b/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py @@ -1,3 +1,5 @@ +# Code modified from smoothquant: https://github.com/mit-han-lab/smoothquant + from typing import Optional, Tuple import torch diff --git a/colossalai/inference/quant/smoothquant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/smoothquant/models/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py b/colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py deleted file mode 100644 index 122b7320e07d..000000000000 --- a/colossalai/inference/quant/smoothquant/smoothquant/models/smoothquant_layer.py +++ /dev/null @@ -1,183 +0,0 @@ -from typing import Optional, Tuple - -import torch -from torch import nn -from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T -from torch_int.nn.linear import W8A8B8O8Linear, W8A8BFP32OFP32Linear -from transformers.models.llama.modeling_llama import LlamaAttention - -from colossalai.kernel.triton import int8_rotary_embedding_fwd - - -class LLamaSmoothquantAttention(nn.Module): - def __init__( - self, - hidden_size: int, - num_heads: int, - ): - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - - if (self.head_dim * num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {num_heads})." 
- ) - - self.attention_weight_scale = 1.0 - - self.qk_bmm = BMM_S8T_S8N_F32T(1.0) - self.pv_bmm = BMM_S8T_S8N_S8T(1.0) - - self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) - self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) - self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) - self.out_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) - - self.q_output_scale = torch.tensor([1.0]) - self.k_output_scale = torch.tensor([1.0]) - self.rotary_output_scale = torch.tensor([1.0]) - - def pack( - self, - module: LlamaAttention, - input_scale: float, - q_output_scale: float, - k_output_scale: float, - v_output_scale: float, - out_input_scale: float, - rotary_output_scale: float, - ): - int8_module = LLamaSmoothquantAttention(module.hidden_size, module.head_dim) - int8_module.q_output_scale = q_output_scale - int8_module.k_output_scale = k_output_scale - int8_module.rotary_output_scale = rotary_output_scale - q_output_scale = q_output_scale * module.scaling - module.q_proj.weight *= module.scaling - module.q_proj.bias *= module.scaling - int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, input_scale, q_output_scale) - - int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, input_scale, k_output_scale) - int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, input_scale, v_output_scale) - int8_module.out_proj = W8A8BFP32OFP32Linear.from_float(module.out_proj, out_input_scale) - int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale(q_output_scale, k_output_scale) - - # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 - int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale(1.0 / 127, v_output_scale, out_input_scale) - return int8_module - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - @torch.no_grad() - def forward( - self, - hidden_states: torch.Tensor, - rotary_emb: Tuple[torch.Tensor], - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, seq_len, _ = hidden_states.size() - # get query proj - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - # TODO: rotary embedding - cos = rotary_emb[0] - sin = rotary_emb[1] - int8_rotary_embedding_fwd( - query_states.view(-1, self.num_heads, self.head_dim), - cos, - sin, - self.q_output_scale, - self.rotary_output_scale, - ) - int8_rotary_embedding_fwd( - key_states.view(-1, self.num_heads, self.head_dim), cos, sin, self.k_output_scale, self.rotary_output_scale - ) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(key_states, -1, bsz) - value_states = self._shape(value_states, -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(key_states, -1, bsz) - value_states = self._shape(value_states, -1, bsz) - - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - - query_states = self._shape(query_states, seq_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - 
src_len = key_states.size(1) - attn_weights = self.qk_bmm(query_states, key_states) - - if attn_weights.size() != (bsz * self.num_heads, seq_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, seq_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, seq_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, seq_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, seq_len, src_len) + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - attn_weights = attn_weights.view(bsz * self.num_heads, seq_len, src_len) - - attn_probs = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs.view(bsz, self.num_heads, seq_len, src_len) - attn_probs = attn_probs.view(bsz * self.num_heads, seq_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_probs_reshaped = attn_probs.view(bsz, self.num_heads, seq_len, src_len) - attn_probs = attn_probs_reshaped.view(bsz * self.num_heads, seq_len, src_len) - else: - attn_probs_reshaped = None - - # (A_row V_row)_row = (A_row V_col ^T)_row - attn_probs.mul_(127).round_() - attn_probs = attn_probs.to(torch.int8) - - value_states = value_states.transpose(1, 2).contiguous() - attn_output = self.pv_bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, seq_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, seq_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, seq_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, seq_len, self.num_heads * self.head_dim).contiguous() - attn_output = self.out_proj(attn_output) - - return attn_output, attn_probs_reshaped, past_key_value From fa5ee22d9be52e16e62355e4693aaa3239414c16 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 3 Oct 2023 10:01:12 +0800 Subject: [PATCH 3/5] remove useless code --- tests/test_smoothquant/test_llama_attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py index c4f111aaff2a..3291fe62d0e7 100644 --- a/tests/test_smoothquant/test_llama_attention.py +++ b/tests/test_smoothquant/test_llama_attention.py @@ -9,8 +9,6 @@ from colossalai.inference.quant.smoothquant.models.smoothquant_layer import LLamaSmoothquantAttention from colossalai.kernel.triton import int8_rotary_embedding_fwd - # /home/lcxk/data3/test_tp_infer/ColossalAI/colossalai/inference/quant/smoothquant/models/smoothquant_layer.py - HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -28,7 +26,7 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): """ - adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 """ xq = xq.view(bs, seqlen, num_head, head_dim) xk = xk.view(bs, seqlen, num_head, head_dim) From e9e914f32c8827aa7e573f810a8b35087979dc25 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 3 Oct 2023 10:55:48 +0800 Subject: [PATCH 4/5] fix import error --- .../quant/smoothquant/models/__init__.py | 12 ++++++++++++ tests/test_smoothquant/test_llama_attention.py | 16 ++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py index e69de29bb2d1..9382a10cffa6 100644 --- a/colossalai/inference/quant/smoothquant/models/__init__.py +++ b/colossalai/inference/quant/smoothquant/models/__init__.py @@ -0,0 +1,12 @@ +try: + import torch_int + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + raise ImportError( + "Not install torch_int. 
Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int" + ) + +if HAS_TORCH_INT: + from .smoothquant_layer import LLamaSmoothquantAttention diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py index 3291fe62d0e7..26f35e20c6b2 100644 --- a/tests/test_smoothquant/test_llama_attention.py +++ b/tests/test_smoothquant/test_llama_attention.py @@ -2,11 +2,7 @@ import torch from packaging import version -from colossalai.inference.quant.smoothquant.models.smoothquant_layer import LLamaSmoothquantAttention -from colossalai.kernel.triton import int8_rotary_embedding_fwd - try: - from colossalai.inference.quant.smoothquant.models.smoothquant_layer import LLamaSmoothquantAttention from colossalai.kernel.triton import int8_rotary_embedding_fwd HAS_TRITON = True @@ -14,7 +10,14 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +try: + from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") @@ -48,7 +51,8 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): @pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT, + reason="triton requires cuda version to be higher than 11.4 or not install torch_int", ) def test_llama_context_attention(): head_num = 8 From 9979951bdbecc0978e81cc80b65de6e19e078ecc Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 3 Oct 2023 12:06:09 +0800 Subject: [PATCH 5/5] rename file name --- .../{test_rotary_embedding.py => test_sq_rotary_embedding.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/test_smoothquant/{test_rotary_embedding.py => test_sq_rotary_embedding.py} (100%) diff --git a/tests/test_smoothquant/test_rotary_embedding.py b/tests/test_smoothquant/test_sq_rotary_embedding.py similarity index 100% rename from tests/test_smoothquant/test_rotary_embedding.py rename to tests/test_smoothquant/test_sq_rotary_embedding.py
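For orientation, the sketch below shows how the new module is exercised, mirroring tests/test_smoothquant/test_llama_attention.py: build LLamaSmoothquantAttention, feed int8 hidden states together with rotary cos/sin tables, and read back the fp32 output of out_proj. It assumes a CUDA device with triton and torch-int installed and keeps the default unit scales; it is a usage sketch, not additional patch content.

    import torch
    from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention

    head_num, head_dim, seq_len = 8, 64, 32
    attn = LLamaSmoothquantAttention(head_num * head_dim, head_num).to("cuda")

    # int8 activations, as produced upstream by SmoothQuant-quantized layers
    hidden = torch.randint(-127, 127, (1, seq_len, head_num * head_dim),
                           dtype=torch.int8, device="cuda")
    # rotary tables expected by forward(): shape (seq_len, head_dim // 2) each
    cos = torch.ones(seq_len, head_dim // 2, dtype=torch.float, device="cuda")
    sin = torch.zeros(seq_len, head_dim // 2, dtype=torch.float, device="cuda")

    out, attn_probs, past_kv = attn(hidden, (cos, sin))
    print(out.shape)  # (1, seq_len, head_num * head_dim), fp32 from out_proj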