From c1daa2648d2d27cc19bcfa3537b33374f3917d3a Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 20 Nov 2023 16:30:05 +0800 Subject: [PATCH 1/3] update examples and engine --- colossalai/inference/__init__.py | 4 +- colossalai/inference/engine/__init__.py | 4 +- colossalai/inference/engine/engine.py | 4 +- ...t_llama.py => build_smoothquant_weight.py} | 9 ++- examples/inference/hybrid_gptq_llama.py | 72 ------------------- .../inference/hybrid_smoothquant_llama.py | 69 ------------------ ...hybrid_llama.py => run_llama_inference.py} | 36 +++++++--- tests/test_infer/test_hybrid_bloom.py | 4 +- tests/test_infer/test_hybrid_chatglm2.py | 4 +- tests/test_infer/test_hybrid_llama.py | 4 +- 10 files changed, 41 insertions(+), 169 deletions(-) rename examples/inference/{smoothquant_llama.py => build_smoothquant_weight.py} (90%) delete mode 100644 examples/inference/hybrid_gptq_llama.py delete mode 100644 examples/inference/hybrid_smoothquant_llama.py rename examples/inference/{hybrid_llama.py => run_llama_inference.py} (68%) diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index 2a2aa8e31069..a95205efaa78 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,4 +1,4 @@ -from .engine import CaiInferEngine +from .engine import InferenceEngine from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy -__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"] +__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"] diff --git a/colossalai/inference/engine/__init__.py b/colossalai/inference/engine/__init__.py index 6377ef817301..6e60da695a22 100644 --- a/colossalai/inference/engine/__init__.py +++ b/colossalai/inference/engine/__init__.py @@ -1,3 +1,3 @@ -from .engine import CaiInferEngine +from .engine import InferenceEngine -__all__ = ["CaiInferEngine"] +__all__ = ["InferenceEngine"] diff --git a/colossalai/inference/engine/engine.py b/colossalai/inference/engine/engine.py index 477a7decd8e3..6181b21d972e 100644 --- a/colossalai/inference/engine/engine.py +++ b/colossalai/inference/engine/engine.py @@ -27,9 +27,9 @@ ] -class CaiInferEngine: +class InferenceEngine: """ - CaiInferEngine is a class that handles the pipeline parallel inference. + InferenceEngine is a class that handles the pipeline parallel inference. Args: tp_size (int): the size of tensor parallelism. diff --git a/examples/inference/smoothquant_llama.py b/examples/inference/build_smoothquant_weight.py similarity index 90% rename from examples/inference/smoothquant_llama.py rename to examples/inference/build_smoothquant_weight.py index ce7a00aa2739..0cb566886f0e 100644 --- a/examples/inference/smoothquant_llama.py +++ b/examples/inference/build_smoothquant_weight.py @@ -29,7 +29,7 @@ def parse_args(): type=str, help="location of the calibration dataset", ) - parser.add_argument("--num-samples", type=int, default=512) + parser.add_argument("--num-samples", type=int, default=10) parser.add_argument("--seq-len", type=int, default=512) args = parser.parse_args() return args @@ -41,13 +41,12 @@ def main(): model_path = args.model_name dataset_path = args.dataset_path output_path = args.output_path - num_samples = 10 - seq_len = 512 + num_samples = args.num_samples + seq_len = args.seq_len model, tokenizer = build_model_and_tokenizer(model_path) if not os.path.exists(dataset_path): - print(f"Cannot find the dataset at {args.dataset_path}") - raise FileNotFoundError + raise FileNotFoundError(f"Cannot find the dataset at {args.dataset_path}") dataset = load_dataset("json", data_files=dataset_path, split="train") model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len) diff --git a/examples/inference/hybrid_gptq_llama.py b/examples/inference/hybrid_gptq_llama.py deleted file mode 100644 index 4adbd413f21a..000000000000 --- a/examples/inference/hybrid_gptq_llama.py +++ /dev/null @@ -1,72 +0,0 @@ -import argparse -import os - -import torch -import torch.distributed as dist -from auto_gptq import AutoGPTQForCausalLM - -import colossalai -from colossalai.inference import CaiInferEngine -from colossalai.logging import disable_existing_loggers -from colossalai.testing import spawn - -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" - - -def run_llama_inference(args): - quantized_model_dir = args.quantized_path - max_batch_size = args.max_batch_size - max_input_len = args.max_input_len - max_output_len = args.max_output_len - micro_batch_size = args.micro_batch_size - # load quantized model to the first GPU - model = AutoGPTQForCausalLM.from_quantized( - quantized_model_dir, inject_fused_attention=False, device=torch.cuda.current_device() - ) - - engine = CaiInferEngine( - tp_size=2, - pp_size=2, - model=model, - max_batch_size=max_batch_size, - max_input_len=max_input_len, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, - quant="gptq", - ) - - def data_gen(): - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - inputs = data_gen() - for k, v in inputs.items(): - if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 16 - inputs[k] = v.to("cuda").repeat(*new_shape) - - output = engine.generate(inputs) - if dist.get_rank() == 0: - assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" - - -def run_gptq_infernece(rank, world_size, port, args): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_llama_inference(args) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) - parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size") - parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size") - parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size") - parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size") - parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length") - parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length") - args = parser.parse_args() - - spawn(run_gptq_infernece, args.tp_size * args.pp_size, args=args) diff --git a/examples/inference/hybrid_smoothquant_llama.py b/examples/inference/hybrid_smoothquant_llama.py deleted file mode 100644 index 7cb264bc51ad..000000000000 --- a/examples/inference/hybrid_smoothquant_llama.py +++ /dev/null @@ -1,69 +0,0 @@ -import argparse - -import torch -import torch.distributed as dist - -import colossalai -from colossalai.inference import CaiInferEngine -from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM -from colossalai.logging import disable_existing_loggers -from colossalai.testing import spawn - - -@torch.no_grad() -def run_llama_inference(args): - quantized_model_dir = args.quantized_path - max_batch_size = args.max_batch_size - max_input_len = args.max_input_len - max_output_len = args.max_output_len - micro_batch_size = args.micro_batch_size - - def data_gen(): - input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) - attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) - return dict(input_ids=input_ids, attention_mask=attention_mask) - - inputs = data_gen() - for k, v in inputs.items(): - if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 16 - inputs[k] = v.to("cuda").repeat(*new_shape) - - model = SmoothLlamaForCausalLM.from_quantized(quantized_model_dir, model_basename="llama-7b") - model = model.cuda() - - engine = CaiInferEngine( - tp_size=2, - pp_size=2, - model=model, - max_batch_size=max_batch_size, - max_input_len=max_input_len, - max_output_len=max_output_len, - micro_batch_size=micro_batch_size, - quant="smoothquant", - ) - - output = engine.generate(inputs) - if dist.get_rank() == 0: - assert len(output[0]) == 32, f"{len(output)}, {32}" - - -def run_smoothquant_inference(rank, world_size, port, args): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_llama_inference(args) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) - parser.add_argument("--tp_size", type=int, default=2, help="Tensor parallel size") - parser.add_argument("--pp_size", type=int, default=2, help="Pipeline parallel size") - parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size") - parser.add_argument("--micro_batch_size", type=int, default=4, help="Micro batch size") - parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length") - parser.add_argument("--max_output_len", type=int, default=32, help="Maximum output length") - - args = parser.parse_args() - spawn(run_smoothquant_inference, args.tp_size * args.pp_size, args=args) diff --git a/examples/inference/hybrid_llama.py b/examples/inference/run_llama_inference.py similarity index 68% rename from examples/inference/hybrid_llama.py rename to examples/inference/run_llama_inference.py index 1bd34afefb79..cb4fe1b74aa4 100644 --- a/examples/inference/hybrid_llama.py +++ b/examples/inference/run_llama_inference.py @@ -3,11 +3,10 @@ import torch import torch.distributed as dist -import transformers from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai -from colossalai.inference import CaiInferEngine +from colossalai.inference import InferenceEngine from colossalai.testing import spawn @@ -21,18 +20,24 @@ def run_inference(args): pp_size = args.pp_size rank = dist.get_rank() - tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) - tokenizer.pad_token_id = tokenizer.unk_token_id - model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) - model = model.half() + if args.quant is None: + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + elif args.quant == "gptq": + from auto_gptq import AutoGPTQForCausalLM - model = transformers.LlamaForCausalLM( - transformers.LlamaConfig( - vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4 + model = AutoGPTQForCausalLM.from_quantized( + llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device() ) - ) + elif args.quant == "smoothquant": + from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM + + model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name) + model = model.cuda() - engine = CaiInferEngine( + engine = InferenceEngine( tp_size=tp_size, pp_size=pp_size, model=model, @@ -75,6 +80,15 @@ def run_tp_pipeline_inference(rank, world_size, port, args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument( + "-q", + "--quant", + type=str, + choice=["gptq", "smoothquant"], + default=None, + help="quantization type: 'gptq' or 'smoothquant'", + ) + parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name") parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Tensor parallel size") parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size") diff --git a/tests/test_infer/test_hybrid_bloom.py b/tests/test_infer/test_hybrid_bloom.py index e344671ec85d..86349d96370e 100644 --- a/tests/test_infer/test_hybrid_bloom.py +++ b/tests/test_infer/test_hybrid_bloom.py @@ -7,7 +7,7 @@ from packaging import version import colossalai -from colossalai.inference import CaiInferEngine +from colossalai.inference import InferenceEngine from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") @@ -36,7 +36,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4) ) - engine = CaiInferEngine( + engine = InferenceEngine( tp_size=tp_size, pp_size=pp_size, model=model, diff --git a/tests/test_infer/test_hybrid_chatglm2.py b/tests/test_infer/test_hybrid_chatglm2.py index 019b4c0b0d20..c51fddc7dd43 100644 --- a/tests/test_infer/test_hybrid_chatglm2.py +++ b/tests/test_infer/test_hybrid_chatglm2.py @@ -6,7 +6,7 @@ from packaging import version import colossalai -from colossalai.inference import CaiInferEngine +from colossalai.inference import InferenceEngine from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -44,7 +44,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): ) model = ChatGLMForConditionalGeneration(chatglm_config) - engine = CaiInferEngine( + engine = InferenceEngine( tp_size=tp_size, pp_size=pp_size, model=model, diff --git a/tests/test_infer/test_hybrid_llama.py b/tests/test_infer/test_hybrid_llama.py index 05530729c676..47f168ddb77b 100644 --- a/tests/test_infer/test_hybrid_llama.py +++ b/tests/test_infer/test_hybrid_llama.py @@ -7,7 +7,7 @@ from packaging import version import colossalai -from colossalai.inference import CaiInferEngine +from colossalai.inference import InferenceEngine from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") @@ -41,7 +41,7 @@ def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): ) ) - engine = CaiInferEngine( + engine = InferenceEngine( tp_size=tp_size, pp_size=pp_size, model=model, From bfdce45097e80236536e9cfc75263a0fe86f440d Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 20 Nov 2023 16:41:39 +0800 Subject: [PATCH 2/3] fix choices --- examples/inference/run_llama_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py index cb4fe1b74aa4..7560adaad137 100644 --- a/examples/inference/run_llama_inference.py +++ b/examples/inference/run_llama_inference.py @@ -84,7 +84,7 @@ def run_tp_pipeline_inference(rank, world_size, port, args): "-q", "--quant", type=str, - choice=["gptq", "smoothquant"], + choices=["gptq", "smoothquant"], default=None, help="quantization type: 'gptq' or 'smoothquant'", ) From c6f4a7d4c6323e7554a78ec106c564b0459a225b Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 20 Nov 2023 18:42:36 +0800 Subject: [PATCH 3/3] update example --- colossalai/inference/engine/engine.py | 30 +------------ examples/inference/run_llama_inference.py | 51 +++++++++-------------- requirements/requirements-infer.txt | 1 - 3 files changed, 22 insertions(+), 60 deletions(-) diff --git a/colossalai/inference/engine/engine.py b/colossalai/inference/engine/engine.py index 6181b21d972e..1e9f5ce2b406 100644 --- a/colossalai/inference/engine/engine.py +++ b/colossalai/inference/engine/engine.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from transformers.tokenization_utils_base import BatchEncoding from transformers.utils import logging from colossalai.cluster import ProcessGroupMesh @@ -42,27 +41,6 @@ class InferenceEngine: max_input_len (int): the maximum input length. max_output_len (int): the maximum output length. - Example: - - ```python - from colossalai.inference import InferEngine - from colossalai.inference.pipeline.policies import LlamaModelInferPolicy - import colossalai - from transformers import LlamaForCausalLM, LlamaTokenizer - - colossalai.launch_from_torch(config={}) - - model = LlamaForCausalLM.from_pretrained("your_path_to_model") - tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") - # assume the model is infered with 2 pipeline stages - inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy()) - - input = ["Introduce a landmark in China ","Introduce a landmark in China "] - data = tokenizer(input, return_tensors='pt') - output = inferengine.inference([data.to('cuda').data]) - - ``` - """ def __init__( @@ -146,7 +124,7 @@ def __init__( if quant == "gptq": self.gptq_manager.post_init_gptq_buffer(self.model) - def generate(self, input_list: Union[BatchEncoding, dict]): + def generate(self, input_list: Union[list, dict]): """ Args: input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`. @@ -155,11 +133,7 @@ def generate(self, input_list: Union[BatchEncoding, dict]): out (list): a list of output data, each element is a list of token. timestamp (float): the time cost of the inference, only return when verbose is `True`. """ - assert isinstance( - input_list, (BatchEncoding, dict) - ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." - if isinstance(input_list, BatchEncoding): - input_list = input_list.data + out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) if self.verbose: return out, timestamp diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py index 7560adaad137..8151518fee14 100644 --- a/examples/inference/run_llama_inference.py +++ b/examples/inference/run_llama_inference.py @@ -1,5 +1,4 @@ import argparse -import time import torch import torch.distributed as dist @@ -11,7 +10,9 @@ def run_inference(args): - llama_model_path = args.path + llama_model_path = args.model_path + llama_tokenize_path = args.tokenizer_path + max_input_len = args.max_input_len max_output_len = args.max_output_len max_batch_size = args.batch_size @@ -20,10 +21,11 @@ def run_inference(args): pp_size = args.pp_size rank = dist.get_rank() + tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left") + tokenizer.pad_token_id = tokenizer.unk_token_id + if args.quant is None: - tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) - tokenizer.pad_token_id = tokenizer.unk_token_id - model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.unk_token_id) model = model.half() elif args.quant == "gptq": from auto_gptq import AutoGPTQForCausalLM @@ -41,8 +43,10 @@ def run_inference(args): tp_size=tp_size, pp_size=pp_size, model=model, + max_input_len=max_input_len, max_output_len=max_output_len, micro_batch_size=micro_batch_size, + quant=args.quant, ) input_tokens = { @@ -50,26 +54,9 @@ def run_inference(args): "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } - iters = 10 - warmup = 3 - times = [] - - for i in range(iters): - torch.cuda.synchronize() - start = time.time() - outputs = engine.generate(input_tokens) - torch.cuda.synchronize() - end = time.time() - if rank == 0: - out_len = len(outputs[0]) - print("generation time {} s".format(str(end - start))) - print(out_len) - times.append((end - start) / out_len) + outputs = engine.generate(input_tokens) if rank == 0: - times = times[warmup:] - latency = sum(times) / len(times) - print("total process latency is : " + str(latency) + " s") - print("total throughput is : " + str(1 / latency * max_batch_size)) + print(tokenizer.batch_decode(outputs)) def run_tp_pipeline_inference(rank, world_size, port, args): @@ -79,7 +66,9 @@ def run_tp_pipeline_inference(rank, world_size, port, args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True) + parser.add_argument("--tokenizer_path", type=str, help="Tokenizer path", required=True) + parser.add_argument( "-q", "--quant", @@ -89,12 +78,12 @@ def run_tp_pipeline_inference(rank, world_size, port, args): help="quantization type: 'gptq' or 'smoothquant'", ) parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name") - parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size") - parser.add_argument("--max_input_len", type=int, default=512, help="Maximum input length") - parser.add_argument("--max_output_len", type=int, default=256, help="Maximum output length") - parser.add_argument("--micro_batch_size", type=int, default=2, help="Micro batch size") + parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size") + parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Pipeline parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size") + parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length") + parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length") + parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size") args = parser.parse_args() spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args) diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt index 3151504df40e..46a6b41bf593 100644 --- a/requirements/requirements-infer.txt +++ b/requirements/requirements-infer.txt @@ -3,5 +3,4 @@ packaging ninja auto-gptq==0.5.0 git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8 -git+https://github.com/facebookresearch/xformers.git@main#egg=xformers git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9