diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py
index 73cdbb39e53f..13d540bfb315 100644
--- a/colossalai/inference/quant/smoothquant/models/base_model.py
+++ b/colossalai/inference/quant/smoothquant/models/base_model.py
@@ -14,7 +14,6 @@
 import torch.nn as nn
 import transformers
 from safetensors.torch import save_file as safe_save
-from torch import device
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
@@ -24,8 +23,6 @@
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
 
-CPU = device("cpu")
-
 SUPPORTED_MODELS = ["llama"]
@@ -204,7 +201,7 @@ def save_quantized(
         if not self.quantized:
             raise EnvironmentError("can only save quantized model, please execute .quantize first.")
 
-        self.model.to(CPU)
+        self.model.to("cpu")
 
         model_base_name = model_basename  # or f"smooth-"
         if use_safetensors:
@@ -431,7 +428,7 @@ def from_quantized(
 
         model_save_name = resolved_archive_file
 
-        # == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
+        # == step2: convert model to quantized-model (replace Linear) == #
         def skip(*args, **kwargs):
             pass
@@ -463,10 +460,10 @@ def skip(*args, **kwargs):
             model.model.register_buffer("_sin_cached", sin)
         model.tie_weights()
 
+        # == step3: load checkpoint to quantized-model == #
         accelerate.utils.modeling.load_checkpoint_in_model(
             model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
         )
-        model = model.to("cuda")
 
         # == step4: set seqlen == #
         model_config = model.config.to_dict()
diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py
index 014fb640e060..4a14066b8379 100644
--- a/colossalai/inference/quant/smoothquant/models/llama.py
+++ b/colossalai/inference/quant/smoothquant/models/llama.py
@@ -1,5 +1,3 @@
-# Code modified from smoothquant: https://github.com/mit-han-lab/smoothquant
-
 import math
 import os
 import types
@@ -92,7 +90,7 @@ def pack(
         out_input_scale: float,
     ):
         int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)
-        # self.register_buffer("attn_input_scale", torch.tensor([1.0]))
+
         int8_module.attn_input_scale = torch.tensor([attn_input_scale])
         int8_module.q_output_scale = torch.tensor([q_output_scale])
@@ -107,10 +105,6 @@ def pack(
         int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
         int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)
-        # int8_module.q_proj = module.q_proj
-        # int8_module.k_proj = module.k_proj
-        # int8_module.v_proj = module.v_proj
-        # int8_module.o_proj = module.o_proj
 
         int8_module.out_input_scale = torch.tensor([out_input_scale])
         return int8_module
@@ -259,10 +253,8 @@ def forward(self, x):
     @staticmethod
     def from_float(module: torch.nn.LayerNorm, output_scale: float):
         assert module.weight.shape[0] == module.weight.numel()
-        # assert module.bias.shape[0] == module.bias.numel()
         q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
         q_module.weight = module.weight / output_scale
-        # q_module.bias = module.bias / output_scale
         return q_module
@@ -346,9 +338,6 @@ def pack(
             out_input_scale,
         )
 
-        # int8_decoder_layer.input_layernorm = module.input_layernorm
-        # int8_decoder_layer.self_attn = module.self_attn
-
         int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
             module.post_attention_layernorm, gate_input_scale
         )
@@ -360,9 +349,6 @@ def pack(
             down_input_scale,
         )
 
-        # int8_decoder_layer.post_attention_layernorm = module.post_attention_layernorm
-        # int8_decoder_layer.mlp = module.mlp
-
         return int8_decoder_layer
 
     def forward(
@@ -641,8 +627,6 @@ def llama_model_forward(
             infer_state.decode_is_contiguous = False
             alloc_mem = infer_state.cache_manager.alloc(batch_size)
             infer_state.decode_mem_index = alloc_mem
-            # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
-            # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
             infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
 
     if position_ids is None:
@@ -673,11 +657,7 @@
     hidden_states = inputs_embeds
 
     if self.gradient_checkpointing and self.training:
-        if use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
+        raise NotImplementedError("gradient checkpointing with training is not implemented")
 
     if past_key_values_length == 0:
         position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
@@ -701,20 +681,17 @@
         past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-        if self.gradient_checkpointing and self.training:
-            raise NotImplementedError("not implement gradient_checkpointing and training options ")
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                rotary_emb=(position_cos, position_sin),
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                padding_mask=padding_mask,
-                infer_state=infer_state,
-            )
+        layer_outputs = decoder_layer(
+            hidden_states,
+            rotary_emb=(position_cos, position_sin),
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            padding_mask=padding_mask,
+            infer_state=infer_state,
+        )
 
         hidden_states = layer_outputs[0]
         infer_state.decode_layer_id += 1
@@ -836,13 +813,12 @@ def quantized(
             scale_dict["q_rotary_output_scale"] = (
                 act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
             )
-
             scale_dict["k_rotary_output_scale"] = (
                 act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
             )
             scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127
-            # mlp scales
+
             scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
             scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
             scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127
diff --git a/examples/inference/smoothquant_llama.py b/examples/inference/smoothquant_llama.py
index 96f6e3730ebf..b27214b5cee2 100644
--- a/examples/inference/smoothquant_llama.py
+++ b/examples/inference/smoothquant_llama.py
@@ -2,6 +2,7 @@
 import os
 
 import torch
+from datasets import load_dataset
 from transformers import LlamaTokenizer
 
 from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
@@ -47,13 +48,15 @@ def main():
     if not os.path.exists(dataset_path):
         print(f"Cannot find the dataset at {args.dataset_path}")
{args.dataset_path}") raise FileNotFoundError + dataset = dataset = load_dataset("json", data_files=dataset_path, split="train") + model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len) model = model.cuda() - model.quantized(tokenizer, dataset_path, num_samples=num_samples, seq_len=seq_len) model.save_quantized(output_path, model_basename="llama-7b") model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b") + model = model.cuda() generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")