From 3257e0e1cd453fdddbfbea7af2bd7a1d992d5aff Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 13 Sep 2023 15:24:53 +0800 Subject: [PATCH 1/7] combine gptq and kv cache manager --- colossalai/gptq/cai_gptq/cai_quant_linear.py | 54 +--- colossalai/gptq/gptq_tp.py | 2 +- .../inference/tensor_parallel/engine.py | 62 +++++ .../tensor_parallel/policies/bloom.py | 21 +- .../tensor_parallel/policies/llama.py | 37 ++- examples/inference/gptq_bloom.py | 131 +++++++++ examples/inference/gptq_llama.py | 257 ++++++++++++++---- tests/test_gptq/test_gptq_linear.py | 49 ++++ 8 files changed, 487 insertions(+), 126 deletions(-) create mode 100644 examples/inference/gptq_bloom.py diff --git a/colossalai/gptq/cai_gptq/cai_quant_linear.py b/colossalai/gptq/cai_gptq/cai_quant_linear.py index 78a37e7bbfb3..93312716992d 100644 --- a/colossalai/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/gptq/cai_gptq/cai_quant_linear.py @@ -147,49 +147,6 @@ def pack(self, linear, scales, zeros, g_idx=None): else: self.g_idx = g_idx - def prepare_buffers(self): - assert self.qweight.device.type == "cuda" - device = self.qweight.device - if self.g_idx is not None: - if self.row_split and torch.equal( - self.g_idx, - torch.tensor( - [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): - self.g_idx = None - elif torch.equal( - self.g_idx, - torch.tensor([i // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): - self.g_idx = None - - CaiQuantLinear.max_dq_buffer_size = max(CaiQuantLinear.max_dq_buffer_size, self.qweight.numel() * 8) - - if self.g_idx is not None: - CaiQuantLinear.max_inner_outer_dim = max(CaiQuantLinear.max_inner_outer_dim, self.infeatures, - self.outfeatures) - CaiQuantLinear.max_input_len = 4096 - # The temp_state buffer is required to reorder X in the act-order case. - # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - CaiQuantLinear.device_to_buffers['temp_state'] = torch.zeros( - (CaiQuantLinear.max_input_len, CaiQuantLinear.max_inner_outer_dim), dtype=torch.float16, device=device) - CaiQuantLinear.device_to_buffers['temp_dp'] = torch.zeros((1, CaiQuantLinear.max_dq_buffer_size), - dtype=torch.float16, - device=device) - - gptq_cuda.prepare_buffers(torch.device(device), CaiQuantLinear.device_to_buffers['temp_state'], - CaiQuantLinear.device_to_buffers['temp_dp']) - - # Using the default from exllama repo here. - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - torch.cuda.empty_cache() - def init_q4(self): assert self.qweight.device.type == "cuda" self.q4_width = self.qweight.shape[1] @@ -219,21 +176,18 @@ def init_q4(self): def forward(self, x): outshape = x.shape[:-1] + (self.outfeatures,) - if HAS_GPTQ_CUDA: - if CaiQuantLinear.prepared_buffers == False: - self.prepare_buffers() - CaiQuantLinear.prepared_buffers = True + if HAS_GPTQ_CUDA and self.bits == 4: if self.q4 is None: self.init_q4() x = x.view(-1, x.shape[-1]) output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) - gptq_cuda.q4_matmul(x, self.q4, output) - if (self.bias is not None and not self.row_split) or self.tp_size == 1: + gptq_cuda.q4_matmul(x.half(), self.q4, output) + if self.bias is not None and (not self.row_split or self.tp_size == 1): output.add_(self.bias) else: - if (self.bias is not None and not self.row_split) or self.tp_size == 1: + if self.bias is not None and (not self.row_split or self.tp_size == 1): bias = self.bias else: bias = None diff --git a/colossalai/gptq/gptq_tp.py b/colossalai/gptq/gptq_tp.py index e8d1d7f00fe8..cc6d184da458 100644 --- a/colossalai/gptq/gptq_tp.py +++ b/colossalai/gptq/gptq_tp.py @@ -95,7 +95,7 @@ def all_reduce_hook(cai_linear, input, output): model_type_name = model.config.model_type gptq_model_config = model_config_map[model_type_name] - layers = get_module_by_name_prefix(model.model, gptq_model_config.layer_blocks) + layers = get_module_by_name_prefix(model, gptq_model_config.layer_blocks) for layer in layers: diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index a5a55702ade0..3b5a90c43ec8 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,15 +1,28 @@ +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch +import torch.distributed as dist import torch.nn as nn from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.tokenization_utils_base import BatchEncoding +from colossalai.gptq.cai_gptq import CaiQuantLinear +from colossalai.gptq.gptq_tp import replace_autogptq_linear from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.auto_policy import get_autopolicy +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager @@ -46,6 +59,7 @@ def __init__(self, max_input_len: int, max_output_len: int, dtype: torch.dtype = torch.float16, + gptq: bool = False, device: str = 'cuda') -> None: self.max_batch_size = max_batch_size self.max_input_len = max_input_len @@ -66,6 +80,14 @@ def __init__(self, self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None + self.gptq = gptq + self.max_dq_buffer_size = 1 + self.max_inner_outer_dim = 1 + self.gptq_temp_state_buffer = None + self.gptq_temp_dq_buffer = None + self.bits = 4 + self.use_act_order = False + self.shard_config = shard_config self.model = None # optimize the original model by sharding with ShardFormer @@ -78,6 +100,38 @@ def _init_manager(self) -> None: self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num) + def _post_init_gptq_buffer(self, model: nn.Module) -> None: + + for name, submodule in model.named_modules(): + if isinstance(submodule, CaiQuantLinear): + self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) + + if self.use_act_order: + self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, + submodule.outfeatures) + + max_input_len = 1 + if self.use_act_order: + max_input_len = self.max_input_len + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), + dtype=torch.float16, + device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, + self.gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + torch.cuda.empty_cache() + def _optimize_model(self, model: nn.Module) -> None: """ Optimize the original model by sharding with ShardFormer. @@ -124,6 +178,14 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(model, inference_only=True) + if not hasattr(policy, "gptq"): + setattr(policy, "gptq", False) + if self.gptq: + setattr(policy, "gptq", True) + tp_rank = dist.get_rank(self.shard_config.tensor_parallel_process_group) + replace_autogptq_linear(model, tp_size=self.tp_size, tp_rank=tp_rank) + if HAS_GPTQ_CUDA and self.bits == 4: + self._post_init_gptq_buffer(model) self.model, _ = shardformer.optimize(model, policy) self.model = self.model.cuda() diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index 63791fe27284..f811e7775094 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -3,6 +3,9 @@ import torch from torch.nn import LayerNorm +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy from ..modeling.bloom import BloomInferenceForwards @@ -33,7 +36,23 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel - policy = super().module_policy() + policy = {} + if not self.gptq: + policy = super().module_policy() + else: + policy[BloomModel] = ModulePolicyDescription( + attribute_replacement={ + "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + method_replacement={ + "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) # NOTE set inference mode to shard config self.shard_config._infer() diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index e819f2a8810c..de41f9c1353a 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -1,14 +1,13 @@ from functools import partial + import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaModel, - LlamaRMSNorm -) +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm +from colossalai.shardformer.layer import VocabParallelEmbedding1D +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward try: @@ -18,23 +17,34 @@ print("you should install triton from https://github.com/openai/triton") HAS_TRITON_RMSNORM = False - + def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) - + return _triton_rmsnorm_forward else: return None - + + class LlamaModelInferPolicy(LlamaForCausalLMPolicy): def __init__(self) -> None: super().__init__() def module_policy(self): - policy = super().module_policy() + policy = {} + if not self.gptq: + policy = super().module_policy() + else: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel) self.shard_config._infer() infer_forward = LlamaInferenceForwards.llama_model_forward @@ -59,12 +69,11 @@ def module_policy(self): else: # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 infer_forward = get_llama_vllm_rmsnorm_forward() - + if infer_forward is not None: method_replacement = {'forward': partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaRMSNorm) + policy=policy, + target_key=LlamaRMSNorm) return policy - diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py new file mode 100644 index 000000000000..47ad9e8114b6 --- /dev/null +++ b/examples/inference/gptq_bloom.py @@ -0,0 +1,131 @@ +import argparse +import logging +import os +import time + +import torch +from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig +from auto_gptq.nn_modules.qlinear import GeneralQuantLinear +from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def bench_bloom(args): + # model_path = args.path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + pretrained_model_dir = "/home/lczyh/data3/models/bloom-7b1" + quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/bloom-7b-no-act-4bit" + + tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir) + tokenizer.pad_token = tokenizer.eos_token + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + + model = model.half() + + model_config = model.config + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len, gptq=True) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + # inference with model.generate + print("input is:", "auto-gptq is") + print( + tokenizer.decode( + infer_engine.generate(tokenizer("auto-gptq is", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) + print("input is:", "today is") + print( + tokenizer.decode( + infer_engine.generate(tokenizer("today is ", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + + # init TPInferEngine and shard the original model + # To benchmark torch original, comment out the line of optimizing model + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + # prepare data for generation + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + input_tokens = { + "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), + "attention_mask": torch.ones((max_batch_size, max_input_len)) + } + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + # print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") + + iters = 10 + times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end - start) / (out_len - max_input_len)) + + print_perf_stats(times, model_config, max_batch_size) + + +def check_bloom(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + bench_bloom(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(args): + spawn(check_bloom, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + + args = parser.parse_args() + + test_bloom(args) diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index ae398740dcdb..2d20dde8ff89 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -1,71 +1,208 @@ +import argparse import logging +import os +import time import torch from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig from auto_gptq.nn_modules.qlinear import GeneralQuantLinear from torch import distributed as dist +from torch.profiler import ProfilerActivity, profile, record_function from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, TextGenerationPipeline +import colossalai from colossalai.gptq import CaiQuantLinear from colossalai.gptq.gptq_tp import replace_autogptq_linear +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", - level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S") -dist.init_process_group(backend="nccl") -pretrained_model_dir = "/data/scratch/llama-7b-hf" -# quantized_model_dir = "llama-7b-with-act-4bit" -quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/llama-7b-no-act-4bit" -rank = dist.get_rank() -world_size = dist.get_world_size() -# rank = 1 -# world_size=2 -torch.cuda.set_device(rank) -print("world size {0} rank {1} deivce {2}".format(world_size, rank, torch.cuda.current_device())) -tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) -examples = [ - tokenizer( - "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.") -] - -# quantize_config = BaseQuantizeConfig( -# bits=4, # quantize model to 4-bit -# group_size=128, # it is recommended to set the value to 128 -# desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad -# ) - -# # load un-quantized model, by default, the model will always be loaded into CPU memory -# model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) - -# # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" -# model.quantize(examples) - -# # save quantized model -# model.save_quantized(quantized_model_dir) - -# # save quantized model using safetensors -# model.save_quantized(quantized_model_dir, use_safetensors=True) - -# load quantized model to the first GPU -model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, - device=torch.cuda.current_device(), - inject_fused_attention=False) - -replace_autogptq_linear(model, tp_size=world_size, tp_rank=rank) - -# if rank == 0: -# print(model.config) -# print(model) -# download quantized model from Hugging Face Hub and load to the first GPU -# model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False) - -# inference with model.generate -print("input is:", "auto-gptq is") -print( - tokenizer.decode( - model.generate(**tokenizer("auto-gptq is", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) -dist.barrier() -print("input is:", "today is") -print( - tokenizer.decode( - model.generate(**tokenizer("today is ", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) +# logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +# level=logging.INFO, +# datefmt="%Y-%m-%d %H:%M:%S") +# dist.init_process_group(backend="nccl") +# pretrained_model_dir = "/data/scratch/llama-7b-hf" +# # quantized_model_dir = "llama-7b-with-act-4bit" +# quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/llama-7b-no-act-4bit" +# rank = dist.get_rank() +# world_size = dist.get_world_size() +# # rank = 1 +# # world_size=2 +# torch.cuda.set_device(rank) +# print("world size {0} rank {1} deivce {2}".format(world_size, rank, torch.cuda.current_device())) +# tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) +# examples = [ +# tokenizer( +# "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.") +# ] + +# # quantize_config = BaseQuantizeConfig( +# # bits=4, # quantize model to 4-bit +# # group_size=128, # it is recommended to set the value to 128 +# # desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad +# # ) + +# # # load un-quantized model, by default, the model will always be loaded into CPU memory +# # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) + +# # # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" +# # model.quantize(examples) + +# # # save quantized model +# # model.save_quantized(quantized_model_dir) + +# # # save quantized model using safetensors +# # model.save_quantized(quantized_model_dir, use_safetensors=True) + +# # load quantized model to the first GPU +# model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, +# device=torch.cuda.current_device(), +# inject_fused_attention=False) + +# replace_autogptq_linear(model, tp_size=world_size, tp_rank=rank) + +# # if rank == 0: +# # print(model.config) +# # print(model) +# # download quantized model from Hugging Face Hub and load to the first GPU +# # model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False) + +# # inference with model.generate +# print("input is:", "auto-gptq is") +# print( +# tokenizer.decode( +# model.generate(**tokenizer("auto-gptq is", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) +# dist.barrier() +# print("input is:", "today is") +# print( +# tokenizer.decode( +# model.generate(**tokenizer("today is ", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) + + +def run_llama_test(args): + # llama_model_path = args.path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + + pretrained_model_dir = "/data/scratch/llama-7b-hf" + # quantized_model_dir = "llama-7b-with-act-4bit" + quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/llama-7b-no-act-4bit" + + tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) + tokenizer.pad_token_id = tokenizer.eos_token_id + + # load quantized model to the first GPU + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) + + init_to_get_rotary(model.model.model, base=10000) + + model_config = model.config + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len, gptq=True) + + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + # inference with model.generate + print("input is:", "auto-gptq is") + print( + tokenizer.decode( + infer_engine.generate(tokenizer("auto-gptq is", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) + dist.barrier() + print("input is:", "today is") + print( + tokenizer.decode( + infer_engine.generate(tokenizer("today is ", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + } + + iters = 10 + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") + times.append((end - start) / (out_len - max_input_len)) + + print("outputs, ", len(outputs)) + print_perf_stats(times, model_config, max_batch_size) + + +def check_llama(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test(args) + + +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(args): + spawn(check_llama, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + + args = parser.parse_args() + + test_llama(args) diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index 0d0343a5c407..889732786fd8 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -28,6 +28,17 @@ HAS_AUTO_GPTQ = False print("please install triton from https://github.com/PanQiWei/AutoGPTQ") +import warnings + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn('CUDA gptq is not installed') + HAS_GPTQ_CUDA = False + TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') wbits = 4 @@ -231,6 +242,43 @@ def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize return qweight, qscales, qzeros +max_inner_outer_dim = 1 +max_input_len = 1 +max_dq_buffer_size = 1 +gptq_temp_dq_buffer = None +gptq_temp_state_buffer = None + + +def init_buffer(cai_linear, use_act_order=False): + global max_dq_buffer_size + global max_input_len + global max_dq_buffer_size + global max_inner_outer_dim + global gptq_temp_dq_buffer + global gptq_temp_state_buffer + + max_dq_buffer_size = max(max_dq_buffer_size, cai_linear.qweight.numel() * 8) + + if use_act_order: + max_inner_outer_dim = max(max_inner_outer_dim, cai_linear.infeatures, cai_linear.outfeatures) + + if use_act_order: + max_input_len = 4096 + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) + gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) + + gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + @pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") def test_gptq_linear(): @@ -279,6 +327,7 @@ def test_gptq_linear(): cai_linear.to("cuda") cai_linear.pack(linear.linear, scale, zero, g_idx) cai_linear.to("cuda") + init_buffer(cai_linear) gptq_model = model_pack(linear, quantizers, wbits, groupsize) gptq_model.to(torch.cuda.current_device()) From 5b8604c7a2dfa14a6ddf8825770f0cfbf0776c6b Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 13 Sep 2023 15:34:48 +0800 Subject: [PATCH 2/7] add init bits --- colossalai/inference/tensor_parallel/engine.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 3b5a90c43ec8..2699389b2646 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -85,7 +85,7 @@ def __init__(self, self.max_inner_outer_dim = 1 self.gptq_temp_state_buffer = None self.gptq_temp_dq_buffer = None - self.bits = 4 + self.bits = -1 self.use_act_order = False self.shard_config = shard_config @@ -109,6 +109,9 @@ def _post_init_gptq_buffer(self, model: nn.Module) -> None: if self.use_act_order: self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures) + self.bits = submodule.bits + if not (HAS_GPTQ_CUDA and self.bits == 4): + return max_input_len = 1 if self.use_act_order: @@ -184,8 +187,7 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: setattr(policy, "gptq", True) tp_rank = dist.get_rank(self.shard_config.tensor_parallel_process_group) replace_autogptq_linear(model, tp_size=self.tp_size, tp_rank=tp_rank) - if HAS_GPTQ_CUDA and self.bits == 4: - self._post_init_gptq_buffer(model) + self._post_init_gptq_buffer(model) self.model, _ = shardformer.optimize(model, policy) self.model = self.model.cuda() From 5e8ab734c9317a75b0198b2b1acc23881acf2ddd Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 13 Sep 2023 15:38:34 +0800 Subject: [PATCH 3/7] delete useless code --- examples/inference/gptq_llama.py | 61 -------------------------------- 1 file changed, 61 deletions(-) diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index 2d20dde8ff89..e534a746e3bf 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -18,67 +18,6 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -# logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", -# level=logging.INFO, -# datefmt="%Y-%m-%d %H:%M:%S") -# dist.init_process_group(backend="nccl") -# pretrained_model_dir = "/data/scratch/llama-7b-hf" -# # quantized_model_dir = "llama-7b-with-act-4bit" -# quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/llama-7b-no-act-4bit" -# rank = dist.get_rank() -# world_size = dist.get_world_size() -# # rank = 1 -# # world_size=2 -# torch.cuda.set_device(rank) -# print("world size {0} rank {1} deivce {2}".format(world_size, rank, torch.cuda.current_device())) -# tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) -# examples = [ -# tokenizer( -# "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm.") -# ] - -# # quantize_config = BaseQuantizeConfig( -# # bits=4, # quantize model to 4-bit -# # group_size=128, # it is recommended to set the value to 128 -# # desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad -# # ) - -# # # load un-quantized model, by default, the model will always be loaded into CPU memory -# # model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) - -# # # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask" -# # model.quantize(examples) - -# # # save quantized model -# # model.save_quantized(quantized_model_dir) - -# # # save quantized model using safetensors -# # model.save_quantized(quantized_model_dir, use_safetensors=True) - -# # load quantized model to the first GPU -# model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, -# device=torch.cuda.current_device(), -# inject_fused_attention=False) - -# replace_autogptq_linear(model, tp_size=world_size, tp_rank=rank) - -# # if rank == 0: -# # print(model.config) -# # print(model) -# # download quantized model from Hugging Face Hub and load to the first GPU -# # model = AutoGPTQForCausalLM.from_quantized(repo_id, device="cuda:0", use_safetensors=True, use_triton=False) - -# # inference with model.generate -# print("input is:", "auto-gptq is") -# print( -# tokenizer.decode( -# model.generate(**tokenizer("auto-gptq is", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) -# dist.barrier() -# print("input is:", "today is") -# print( -# tokenizer.decode( -# model.generate(**tokenizer("today is ", return_tensors="pt").to(model.device), max_new_tokens=128)[0])) - os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' From 0212cb2cfc2b61357e37db863fe0558bbcfeb09d Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 13 Sep 2023 15:43:06 +0800 Subject: [PATCH 4/7] add model path --- examples/inference/gptq_bloom.py | 10 +++++----- examples/inference/gptq_llama.py | 10 ++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index 47ad9e8114b6..8d4e2d12da11 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -37,14 +37,13 @@ def print_perf_stats(latency_set, config, bs, warmup=3): def bench_bloom(args): - # model_path = args.path + + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path max_batch_size = args.batch_size max_input_len = args.input_len max_output_len = args.output_len - pretrained_model_dir = "/home/lczyh/data3/models/bloom-7b1" - quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/bloom-7b-no-act-4bit" - tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir) tokenizer.pad_token = tokenizer.eos_token @@ -120,7 +119,8 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - # parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-p', '--quantized_path', type=str, help='Model path', required=True) parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index e534a746e3bf..2908e556fdb7 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -64,15 +64,12 @@ def print_perf_stats(latency_set, config, bs, warmup=3): def run_llama_test(args): - # llama_model_path = args.path + pretrained_model_dir = args.path + quantized_model_dir = args.quantized_path max_batch_size = args.batch_size max_input_len = args.input_len max_output_len = args.output_len - pretrained_model_dir = "/data/scratch/llama-7b-hf" - # quantized_model_dir = "llama-7b-with-act-4bit" - quantized_model_dir = "/home/lcxk/data3/test_gptq_llama/llama-7b-no-act-4bit" - tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) tokenizer.pad_token_id = tokenizer.eos_token_id @@ -136,7 +133,8 @@ def test_llama(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - # parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-p', '--quantized_path', type=str, help='Model path', required=True) parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') From 094649fb45133133ff207f6d1e3f5d1cb8944813 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 14 Sep 2023 13:42:10 +0800 Subject: [PATCH 5/7] delete usless print and update test --- examples/inference/gptq_bloom.py | 12 +- examples/inference/gptq_llama.py | 14 +- tests/test_gptq/test_gptq_linear.py | 307 ++++------------------------ 3 files changed, 46 insertions(+), 287 deletions(-) diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index 8d4e2d12da11..22d591764465 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -59,16 +59,6 @@ def bench_bloom(args): infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len, gptq=True) generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) - # inference with model.generate - print("input is:", "auto-gptq is") - print( - tokenizer.decode( - infer_engine.generate(tokenizer("auto-gptq is", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) - print("input is:", "today is") - print( - tokenizer.decode( - infer_engine.generate(tokenizer("today is ", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) - input_tokens = { "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') @@ -120,7 +110,7 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-p', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index 2908e556fdb7..7c28d359267c 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -86,17 +86,6 @@ def run_llama_test(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) - # inference with model.generate - print("input is:", "auto-gptq is") - print( - tokenizer.decode( - infer_engine.generate(tokenizer("auto-gptq is", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) - dist.barrier() - print("input is:", "today is") - print( - tokenizer.decode( - infer_engine.generate(tokenizer("today is ", return_tensors="pt").to('cuda'), max_new_tokens=128)[0])) - input_tokens = { "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') @@ -115,7 +104,6 @@ def run_llama_test(args): print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s") times.append((end - start) / (out_len - max_input_len)) - print("outputs, ", len(outputs)) print_perf_stats(times, model_config, max_batch_size) @@ -134,7 +122,7 @@ def test_llama(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-p', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index 889732786fd8..a75dac1568e5 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -17,10 +17,14 @@ print("please install triton from https://github.com/openai/triton") try: + from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length from auto_gptq.modeling._utils import autogptq_post_init, find_layers, pack_model + from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear from auto_gptq.quantization import GPTQ from auto_gptq.quantization.quantizer import Quantizer + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + from exllama_kernels import prepare_buffers, set_tuning_params from colossalai.gptq import CaiGPTQLinearOp, CaiQuantLinear HAS_AUTO_GPTQ = True @@ -41,207 +45,6 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -wbits = 4 -trits = False -nsamples = 1 -percdamp = .01 -groupsize = 128 -act_order = False -sym = False - - -class MLinear(nn.Module): - - def __init__(self, infeature, outfeature): - super(MLinear, self).__init__() - self.linear = torch.nn.Linear(infeature, outfeature, dtype=torch.float16) - - def forward(self, x): - out = self.linear(x) - return out - - -@torch.no_grad() -def model_quant(model, inps, dev): - print('Starting ...') - layers = [model] - layers[0] = layers[0].to(dev) - - dtype = next(iter(model.parameters())).dtype - cache = {'i': 0} - - class Catcher(nn.Module): - - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - raise ValueError - - layers[0] = Catcher(layers[0]) - # for batch in inps: - try: - model(inps.to(dev)) - except ValueError: - pass - layers[0] = layers[0].module - - outs = torch.zeros(inps.shape[0], layers[0].linear.weight.shape[0]) - - print('Ready.') - - quantizers = {} - for i in range(len(layers)): - layer = layers[i].to(dev) - subset = find_layers(layer) - gptq = {} - for name in subset: - gptq[name] = GPTQ(subset[name]) - # gptq[name].quantizer = Quantizer() - gptq[name].quantizer.configure(wbits, perchannel=True, sym=sym, mse=False, trits=trits) - - def add_batch(name): - - def tmp(_, inp, out): - gptq[name].add_batch(inp[0].data, out.data) - - return tmp - - handles = [] - for name in subset: - handles.append(subset[name].register_forward_hook(add_batch(name))) - - for j in range(nsamples): - outs[j] = layer(inps[j].unsqueeze(0))[0] - - for h in handles: - h.remove() - for name in subset: - print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') - scale, zero, g_idx = gptq[name].fasterquant(percdamp=percdamp, group_size=groupsize, actorder=act_order) - # quantizers['%s' % (name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) - quantizers['%s' % (name)] = (gptq[name].layer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()) - - gptq[name].free() - for j in range(nsamples): - layer = layer.to(dev) - outs[j] = layer(inps[j].unsqueeze(0))[0] - - layers[i] = layer.cpu() - del layer - del gptq - torch.cuda.empty_cache() - - inps, outs = outs, inps - - return quantizers - - -def model_pack(model, quantizers, wbits, groupsize): - pack_model(model, quantizers, wbits, groupsize) - return model - - -def cai_linear_pack(linear, scales, zeros, out_qweight, out_qscales, out_qzeros, qg_idx, infeatures, groupsize, bits): - g_idx = qg_idx.clone() if qg_idx is not None else torch.tensor([i // groupsize for i in range(infeatures)], - dtype=torch.int32) - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - half_scales = scales.clone().half() - # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) - - out_qscales.data.copy_(scales) - - # wn = 16 - # pbits = 64 - # ptype = torch.int64 - # unsign_type = np.uint64 - # sign_type = np.int64 - - wn = 8 - pbits = 32 - ptype = torch.int32 - unsign_type = np.uint32 - sign_type = np.int32 - - intweight = [] - for idx in range(infeatures): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None]) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(unsign_type) - qweight = np.zeros((intweight.shape[0] // pbits * bits, intweight.shape[1]), dtype=unsign_type) - - i = 0 - row = 0 - # print("weight shape ", intweight.shape, qweight.shape, out_qweight.shape, bits) - # print("weight shape ", intweight[0].shape, qweight[0].shape, out_qweight[0].shape) - # print("weight value ", intweight[0], qweight[0]) - - while row < qweight.shape[0]: - if bits in [2, 4, 8]: - for j in range(i, i + (pbits // bits)): - qweight[row] |= intweight[j] << (bits * (j - i)) - i += pbits // bits - row += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - qweight = qweight.astype(sign_type) - qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous().to("cuda") - out_qweight.data.copy_(qweight1) - - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * bits), dtype=unsign_type) - zeros -= 1 - zeros = zeros.numpy().astype(unsign_type) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if bits in [2, 4, 8]: - for j in range(i, i + (pbits // bits)): - qzeros[:, col] |= zeros[:, j] << (bits * (j - i)) - i += pbits // bits - col += 1 - else: - raise NotImplementedError("Only 2,4,8 bits are supported.") - qzeros = qzeros.astype(sign_type) - qzeros = torch.from_numpy(qzeros) - qzeros = qzeros.to("cuda") - out_qzeros.data.copy_(qzeros) - - return out_qweight, out_qscales, out_qzeros - - -def get_model_param(model, quantizers): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - with torch.no_grad(): - for name in layers: - _, scale, zero, g_idx = quantizers[name] - - return scale, zero, g_idx - - -def model_cai_pack(model, quantizers, qweight, qscales, qzeros, wbits, groupsize): - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} - with torch.no_grad(): - for name in layers: - _, scale, zero, g_idx = quantizers[name] - qweight, qscales, qzeros = cai_linear_pack(layers[name], scale, zero, qweight, qscales, qzeros, g_idx, - layers[name].weight.shape[-1], groupsize, wbits) - - # print("cai pack", layers) - return qweight, qscales, qzeros - - max_inner_outer_dim = 1 max_input_len = 1 max_dq_buffer_size = 1 @@ -283,91 +86,69 @@ def init_buffer(cai_linear, use_act_order=False): reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") def test_gptq_linear(): - infeature = 5120 - outfeature = 5120 + infeature = 1024 + outfeature = 1024 + group_size = 128 + wbits = 4 - weight = torch.randn(outfeature, infeature).to(torch.float16).to(torch.cuda.current_device()) - bias = torch.zeros(outfeature).to(torch.float16).to(torch.cuda.current_device()) - # wn = 16 - # ptype = torch.int64 + inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) + batch_inps = torch.randn(1, 16, infeature).to(torch.float16).to(torch.cuda.current_device()) - wn = 8 - ptype = torch.int32 + device = torch.device("cuda:0") - qweight = torch.zeros(infeature // wn, outfeature, dtype=ptype, device=torch.cuda.current_device()).contiguous() - qscales = torch.zeros(infeature // groupsize, outfeature, dtype=torch.float16, - device=torch.cuda.current_device()).contiguous() - qzeros = torch.zeros(infeature // groupsize, outfeature // wn, dtype=ptype, - device=torch.cuda.current_device()).contiguous() + linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=wbits) - act_func = nn.SiLU() - inps = torch.ones(1, 1, infeature).to(torch.float16).to(torch.cuda.current_device()) - batch_inps = torch.randn(1, 4096, infeature).to(torch.float16).to(torch.cuda.current_device()) + linear = linear_class( + bits=4, + group_size=group_size, + infeatures=infeature, + outfeatures=outfeature, + bias=False, + ) - linear = MLinear(infeature, outfeature) - linear.to(torch.cuda.current_device()) + torch.manual_seed(42) - with torch.no_grad(): - linear.linear.weight.data.copy_(weight) - linear.linear.bias.data.copy_(bias) + linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32) + linear.scales = linear.scales + 0.002 - with torch.no_grad(): - torch_out = linear(inps) - batch_torch_out = linear(batch_inps) - # torch_out = act_func(torch_out) - # batch_torch_out = act_func(batch_torch_out) + linear = linear.to(device) - # linear.to("cuda") - quantizers = model_quant(linear, inps, torch.cuda.current_device()) - # qweight, qscales, qzeros = model_cai_pack(linear, quantizers, qweight, qscales, qzeros, wbits, groupsize) + cai_linear = CaiQuantLinear(wbits, group_size, infeature, outfeature, True) + cai_linear.qweight.data.copy_(linear.qweight) + cai_linear.scales = cai_linear.scales + 0.002 + cai_linear = cai_linear.to(device) - scale, zero, g_idx = get_model_param(linear, quantizers) - cai_linear = CaiQuantLinear(wbits, groupsize, infeature, outfeature, True) + linear = autogptq_post_init(linear, use_act_order=False) - cai_linear.to("cuda") - cai_linear.pack(linear.linear, scale, zero, g_idx) - cai_linear.to("cuda") - init_buffer(cai_linear) + max_inner_outer_dim = max(infeature, outfeature) + max_dq_buffer_size = linear.infeatures * linear.outfeatures + max_input_len = 2048 + buffers = { + "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + } - gptq_model = model_pack(linear, quantizers, wbits, groupsize) - gptq_model.to(torch.cuda.current_device()) - gptq_model = autogptq_post_init(gptq_model, False) + prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) + + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) with torch.no_grad(): - gptq_out = gptq_model(inps) - batch_gptq_out = gptq_model(batch_inps) + gptq_out = linear(inps) + batch_gptq_out = linear(batch_inps) torch.cuda.synchronize() cai_out = cai_linear(inps) torch.cuda.synchronize() batch_cai_out = cai_linear(batch_inps) torch.cuda.synchronize() - # batch_gptq_out = act_func(batch_gptq_out) - # gptq_out = act_func(gptq_out) assert torch.allclose(cai_out, gptq_out, rtol=1e-01, atol=1e-01) assert torch.allclose(batch_cai_out, batch_gptq_out, rtol=1e-01, atol=1e-01) - # mean_diff = torch.mean(torch.abs(cai_out - gptq_out)) - # max_diff = torch.max(torch.abs(cai_out - gptq_out)) - # print("cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(torch_out - gptq_out)) - # max_diff = torch.max(torch.abs(torch_out - gptq_out)) - # print("torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(torch_out - cai_out)) - # max_diff = torch.max(torch.abs(torch_out - cai_out)) - # print("torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - - # mean_diff = torch.mean(torch.abs(batch_cai_out - batch_gptq_out)) - # max_diff = torch.max(torch.abs(batch_cai_out - batch_gptq_out)) - # print("batch cai vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(batch_torch_out - batch_gptq_out)) - # max_diff = torch.max(torch.abs(batch_torch_out - batch_gptq_out)) - # print("batch torch vs gptq: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - # mean_diff = torch.mean(torch.abs(batch_torch_out - batch_cai_out)) - # max_diff = torch.max(torch.abs(batch_torch_out - batch_cai_out)) - # print("batch torch vs cai: mean_diff=%.8f, max_diff=%.8f" % (mean_diff, max_diff)) - if __name__ == "__main__": From a447262c9f86722b24fd18eacbcd9d8a9d50c9c8 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 14 Sep 2023 13:51:27 +0800 Subject: [PATCH 6/7] delete usless import --- tests/test_gptq/test_gptq_linear.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index a75dac1568e5..718060c22908 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -17,16 +17,11 @@ print("please install triton from https://github.com/openai/triton") try: - from auto_gptq import AutoGPTQForCausalLM, exllama_set_max_input_length - from auto_gptq.modeling._utils import autogptq_post_init, find_layers, pack_model - from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear - from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear - from auto_gptq.quantization import GPTQ - from auto_gptq.quantization.quantizer import Quantizer + from auto_gptq.modeling._utils import autogptq_post_init from auto_gptq.utils.import_utils import dynamically_import_QuantLinear from exllama_kernels import prepare_buffers, set_tuning_params - from colossalai.gptq import CaiGPTQLinearOp, CaiQuantLinear + from colossalai.gptq import CaiQuantLinear HAS_AUTO_GPTQ = True except: HAS_AUTO_GPTQ = False From 4fdfdb3483cb5c7d6aa11fd98cd01ba37d762897 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 14 Sep 2023 14:23:35 +0800 Subject: [PATCH 7/7] move option gptq to shard config --- colossalai/inference/tensor_parallel/engine.py | 8 ++------ colossalai/inference/tensor_parallel/policies/bloom.py | 2 +- colossalai/inference/tensor_parallel/policies/llama.py | 2 +- colossalai/shardformer/shard/shard_config.py | 2 +- examples/inference/gptq_bloom.py | 6 ++++-- examples/inference/gptq_llama.py | 6 ++++-- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 2699389b2646..94b44136bebc 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -59,7 +59,6 @@ def __init__(self, max_input_len: int, max_output_len: int, dtype: torch.dtype = torch.float16, - gptq: bool = False, device: str = 'cuda') -> None: self.max_batch_size = max_batch_size self.max_input_len = max_input_len @@ -80,7 +79,6 @@ def __init__(self, self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None - self.gptq = gptq self.max_dq_buffer_size = 1 self.max_inner_outer_dim = 1 self.gptq_temp_state_buffer = None @@ -181,10 +179,8 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(model, inference_only=True) - if not hasattr(policy, "gptq"): - setattr(policy, "gptq", False) - if self.gptq: - setattr(policy, "gptq", True) + + if self.shard_config.inference_gptq: tp_rank = dist.get_rank(self.shard_config.tensor_parallel_process_group) replace_autogptq_linear(model, tp_size=self.tp_size, tp_rank=tp_rank) self._post_init_gptq_buffer(model) diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index f811e7775094..037b0ab85863 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -37,7 +37,7 @@ def __init__(self) -> None: def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel policy = {} - if not self.gptq: + if not self.shard_config.inference_gptq: policy = super().module_policy() else: policy[BloomModel] = ModulePolicyDescription( diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index de41f9c1353a..6b6056501ac0 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -36,7 +36,7 @@ def __init__(self) -> None: def module_policy(self): policy = {} - if not self.gptq: + if not self.shard_config.inference_gptq: policy = super().module_policy() else: self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 4380ac30814d..303e0b008041 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -33,9 +33,9 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False inference_only: bool = False + inference_gptq: bool = False enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False - # pipeline_parallel_size: int # data_parallel_size: int # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index 22d591764465..43e118cc0aa5 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -56,7 +56,7 @@ def bench_bloom(args): model_config = model.config shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) - infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len, gptq=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { @@ -66,7 +66,9 @@ def bench_bloom(args): # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) # prepare data for generation diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index 7c28d359267c..818ae0035e87 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -81,8 +81,10 @@ def run_llama_test(args): init_to_get_rotary(model.model.model, base=10000) model_config = model.config - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) - infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len, gptq=True) + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)