diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index 2fb76d3e5e58..6a3f961f7054 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -7,7 +7,6 @@
 from transformers.generation.stopping_criteria import StoppingCriteriaList
 from transformers.tokenization_utils_base import BatchEncoding
 
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.auto_policy import get_autopolicy
 
@@ -29,6 +28,7 @@ def __init__(self,
                  dtype: torch.dtype = torch.float16,
                  device: str = 'cuda') -> None:
         self.model = model
+        self.model = self.model.to(device)
         self.sharded_model = None
 
         self.max_batch_size = max_batch_size
@@ -57,7 +57,18 @@ def _init_manager(self) -> None:
         self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
                                            self.layer_num)
 
-    def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
+    def optimize_model(self, config: Optional[Dict[Any, Any]] = None) -> None:
+        """ Apply shardformer to optimize the model. In future generation, use sharded model instead of original model. """
+        tp_size = 1 if config is None else config.get('tp_size', 1)
+        # NOTE we will change to use an inference config later with additional attrs we want
+        # tp_size = getattr(config, 'tp_size', 1)
+        shard_config = ShardConfig(enable_tensor_parallelism=True if tp_size > 1 else False, inference_only=True)
+        shardformer = ShardFormer(shard_config=shard_config)
+        self._prepare_with_shard_config(shard_config=shard_config)
+        self._shard_model_by(shardformer)
+        self.model = None
+
+    def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
         """ Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """
         self.tp_size = 1
         if shard_config is None:
@@ -80,7 +91,7 @@ def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
 
         return shard_config
 
-    def shard_model_by(self, shardformer: ShardFormer) -> None:
+    def _shard_model_by(self, shardformer: ShardFormer) -> None:
         """ Shard the model and store the sharded model by given ShardFormer """
         assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
             "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
@@ -100,11 +111,13 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
             for t in input_tokens:
                 if torch.is_tensor(input_tokens[t]):
                     input_tokens[t] = input_tokens[t].cuda()
+        if 'max_new_tokens' not in generate_kwargs:
+            generate_kwargs.update(max_new_tokens=self.max_output_len)
 
         if self.sharded_model is not None:
             return self.generate_by_set_infer_state(input_tokens, **generate_kwargs)
 
-        return self.model.generate(**input_tokens, **generate_kwargs)
+        return self.model.generate(input_tokens.get('input_ids'), **generate_kwargs)
 
     @torch.no_grad()
     def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
@@ -135,9 +148,12 @@ def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
         model = self.sharded_model.transformer
         setattr(model, 'infer_state', batch_infer_state)
 
-        generate_kwargs.update(max_new_tokens=self.max_output_len)
         outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
 
+        # NOTE In future development, we're going to let the scheduler to handle the cache,
+        # instead of freeing space explicitly at the end of generation
+        self.cache_manager.free_all()
+
         return outputs
 
     def prepare_batch_state(self, inputs) -> BatchInferState:
diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py
index ce4396b11ba5..c07202ef882b 100644
--- a/examples/inference/bench_bloom.py
+++ b/examples/inference/bench_bloom.py
@@ -48,17 +48,11 @@ def bench_bloom(test_config):
     tokenizer.pad_token = tokenizer.eos_token
     model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()
 
-    # To benchmark torch original, uncommment the following line
-    # model.to(torch.cuda.current_device())
-    # init TPInferEngine and shard original model by shardformer
-    # To benchmark torch original, comment out lines of creating, preparing, and sharding by the shardformer
+    # init TPInferEngine and shard the original model
+    # To benchmark torch original, comment out lines of optimizing model
     infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
-                               inference_only=True)
-    shardformer = ShardFormer(shard_config=shard_config)
-    infer_engine.prepare_with_shard_config(shard_config)
-    infer_engine.shard_model_by(shardformer)
+    infer_engine.optimize_model(test_config)
 
     # prepare data for generation
     batch_size = MAX_BATCH_SIZE
@@ -78,10 +72,9 @@ def bench_bloom(test_config):
     for i in range(iters):
         torch.cuda.synchronize()
         start = time.time()
-        outputs = infer_engine.generate(input_tokens, generate_kwargs)
+        outputs = infer_engine.generate(input_tokens, **generate_kwargs)
         torch.cuda.synchronize()
         end = time.time()
-        infer_engine.cache_manager.free_all()
         out_len = outputs.shape[1]
         print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
         times.append((end - start) / (out_len - input_len))
diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py
index 1aabd340aedd..c1ece952b099 100644
--- a/examples/inference/bench_llama.py
+++ b/examples/inference/bench_llama.py
@@ -77,12 +77,8 @@ def run_llama_test(test_config):
 
     model_config = model.config
 
-    infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
-    shardformer = ShardFormer(shard_config=shard_config)
-
-    infer_engine.prepare_with_shard_config(shard_config)
-    infer_engine.shard_model_by(shardformer)
+    infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine.optimize_model(test_config)
 
     batch_size = 2
     max_new_tokens = 128
@@ -100,13 +96,12 @@ def run_llama_test(test_config):
     for i in range(iters):
         torch.cuda.synchronize()
         start = time.time()
-        outputs = infer_engine.generate(input_tokens, generate_kwargs)
+        outputs = infer_engine.generate(input_tokens, **generate_kwargs)
         torch.cuda.synchronize()
         end = time.time()
         out_len = outputs.shape[1]
         print("generation time {} s".format(str(end - start)))
         times.append((end - start) / (out_len - input_len))
-        infer_engine.cache_manager.free_all()
 
     print("outputs, ", len(outputs))
     outputs = tokenizer.batch_decode(outputs)
diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
index 1f01460994d9..eb55d7d40778 100644
--- a/tests/test_infer/test_bloom_infer.py
+++ b/tests/test_infer/test_bloom_infer.py
@@ -4,13 +4,12 @@
 import torch
 import torch.distributed as dist
 from packaging import version
-from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
+from transformers import AutoTokenizer, BloomForCausalLM
 
 import colossalai
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 TP_SIZE = 2
 MAX_BATCH_SIZE = 4
@@ -20,34 +19,38 @@
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
 
-def run():
-    model_path = "/home/lczyh/data3/models/bloom-7b1"
+@parameterize('test_config', [{
+    'tp_size': TP_SIZE,
+}])
+def run(test_config):
+    model_path = "/data3/models/bloom-7b1"
     if os.path.isdir(model_path) is False:
         return
 
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.pad_token = tokenizer.eos_token
 
-    text = "Introduce some landmarks in Beijing"
-    input_ids = tokenizer.batch_encode_plus([text], return_tensors='pt')
+    text1 = "Introduce some landmarks in Beijing"
+    text2 = "how is weather today?"
+    input_ids = tokenizer.batch_encode_plus([text1, text2], return_tensors='pt', padding=True)
 
     model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()
-    model.to(torch.cuda.current_device())
-
-    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
-    shardformer = ShardFormer(shard_config=shard_config)
 
     infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    infer_engine.prepare_with_shard_config(shard_config=shard_config)
-    infer_engine.shard_model_by(shardformer)
+    infer_engine.optimize_model(test_config)
 
     generate_kwargs = dict(do_sample=False)
     outputs = infer_engine.generate(input_ids, **generate_kwargs)
 
+    assert outputs is not None
+
     if not dist.is_initialized() or dist.get_rank() == 0:
-        output_text = tokenizer.decode(outputs[0])
-        print(output_text)
+        # output_text = tokenizer.decode(outputs[0])
+        # print(output_text)
+        for o in outputs:
+            output_text = tokenizer.decode(o)
+            # print(output_text)
 
 
 def check_engine(rank, world_size, port):
diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py
index bc96ee137353..b4feb10c4573 100644
--- a/tests/test_infer/test_infer_engine.py
+++ b/tests/test_infer/test_infer_engine.py
@@ -85,9 +85,8 @@ def test_orig_generate():
 
     shard_config = ShardConfig(enable_tensor_parallelism=False)
 
-    # init TPInferEngine and
+    # init TPInferEngine
     infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    infer_engine.prepare_with_shard_config(shard_config)
 
     # original model generate
     generate_kwargs = dict(do_sample=False)
@@ -96,18 +95,17 @@ def test_orig_generate():
     torch.cuda.empty_cache()
 
 
-def run():
+@parameterize('test_config', [{
+    'tp_size': TP_SIZE,
+}])
+def run(test_config):
     model_config = BloomConfig()
     model = BloomForCausalLM(model_config)
     model = model.half()
     model.to(torch.cuda.current_device())
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
-    shardformer = ShardFormer(shard_config=shard_config)
-
     infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    infer_engine.prepare_with_shard_config(shard_config=shard_config)
-    infer_engine.shard_model_by(shardformer)
+    infer_engine.optimize_model(test_config)
 
     assert infer_engine.cache_manager is not None
     assert infer_engine.tp_size == TP_SIZE
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index 986f70633289..3b9317cbceb6 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -46,7 +46,10 @@ def init_to_get_rotary(self, base=10000):
     return
 
 
-def run_llama_test():
+@parameterize('test_config', [{
+    'tp_size': TPSIZE,
+}])
+def run_llama_test(test_config):
     llama_model_path = "/data/scratch/llama-7b-hf"
     if os.path.isdir(llama_model_path) is False:
         return
@@ -61,19 +64,14 @@ def run_llama_test():
     text = ["how is weather today?", "i am "]
     input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda')
 
-    #print("input ids ", input_ids)
-    infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
-    shardformer = ShardFormer(shard_config=shard_config)
-
-    infer_engine.prepare_with_shard_config(shard_config)
-    infer_engine.shard_model_by(shardformer)
+    infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine.optimize_model(test_config)
 
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
     outputs = infer_engine.generate(input_ids, **generate_kwargs)
 
-    #print("outputs.shape: ", outputs.shape)
-    #print("outputs: ", outputs)
+    assert outputs is not None
+
     if not dist.is_initialized() or dist.get_rank() == 0:
         for o in outputs:
             output_text = tokenizer.decode(o)
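
For reference, a minimal usage sketch of the refactored engine API introduced above (not part of the patch). It assumes a locally available BLOOM checkpoint (the path below is a placeholder) and a single-process run; running with tp_size > 1 additionally requires launching through colossalai's distributed spawn, as the tests above do.

from transformers import AutoTokenizer, BloomForCausalLM

from colossalai.inference.tensor_parallel import TPInferEngine

MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 4, 16, 8
model_path = "/path/to/bloom-7b1"    # placeholder checkpoint path

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id).half()

# optimize_model() replaces the former prepare_with_shard_config() + shard_model_by() pair:
# the engine builds the ShardConfig and ShardFormer internally, shards the model, and
# drops its reference to the original (unsharded) model.
infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model({'tp_size': 1})

# generate() now defaults max_new_tokens to the engine's max_output_len and frees the
# KV cache itself, so callers no longer call cache_manager.free_all() between iterations.
input_tokens = tokenizer.batch_encode_plus(["how is weather today?"], return_tensors='pt', padding=True)
outputs = infer_engine.generate(input_tokens, do_sample=False)
print(tokenizer.batch_decode(outputs))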