From 6eb051657df9a5ac50d5eaf84253e5b7c825fa86 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 11 Sep 2023 16:47:16 +0800 Subject: [PATCH 1/3] remove using orig model in engine --- .../inference/tensor_parallel/engine.py | 35 ++++++++----------- examples/inference/bench_bloom.py | 1 - examples/inference/bench_llama.py | 1 - tests/test_infer/test_bloom_infer.py | 1 - tests/test_infer/test_infer_engine.py | 9 +++-- tests/test_infer/test_llama_infer.py | 1 - 6 files changed, 19 insertions(+), 29 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index c02ccb6e54ce..2cb115c30770 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -36,7 +36,6 @@ class TPInferEngine: >>> generate_kwargs = ... >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - >>> infer_engine.optimize_model() >>> outputs = infer_engine.generate(input_ids, **generate_kwargs) """ @@ -48,11 +47,6 @@ def __init__(self, max_output_len: int, dtype: torch.dtype = torch.float16, device: str = 'cuda') -> None: - self.model = model - self.model = self.model.to(device) - self.shard_config = shard_config - self.sharded_model = None - self.max_batch_size = max_batch_size self.max_input_len = max_input_len self.max_output_len = max_output_len @@ -65,13 +59,18 @@ def __init__(self, self.dtype = dtype - self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads - self.head_num = self.model.config.num_attention_heads - self.layer_num = self.model.config.num_hidden_layers + self.head_dim = model.config.hidden_size // model.config.num_attention_heads + self.head_num = model.config.num_attention_heads + self.layer_num = model.config.num_hidden_layers self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None + self.shard_config = shard_config + self.sharded_model = None + # optimize the original model by sharding with ShardFormer + self._optimize_model(model=model.to(device)) + def _init_manager(self) -> None: assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" @@ -79,7 +78,7 @@ def _init_manager(self) -> None: self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num) - def optimize_model(self) -> None: + def _optimize_model(self, model: nn.Module) -> None: """ Optimize the original model by sharding with ShardFormer. In further generation, use the sharded model instead of original model. @@ -88,8 +87,7 @@ def optimize_model(self) -> None: assert self.shard_config.inference_only is True shardformer = ShardFormer(shard_config=self.shard_config) self._prepare_with_shard_config(shard_config=self.shard_config) - self._shard_model_by(shardformer) - self.model = None + self._shard_model_by(shardformer, model) def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: """ Prepare the engine with a given ShardConfig. 
@@ -119,14 +117,14 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) return shard_config - def _shard_model_by(self, shardformer: ShardFormer) -> None: + def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: """ Shard original model by the given ShardFormer and store the sharded model. """ assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" - model_name = self.model.__class__.__name__ + model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." - policy = get_autopolicy(self.model, inference_only=True) - self.sharded_model, _ = shardformer.optimize(self.model, policy) + policy = get_autopolicy(model, inference_only=True) + self.sharded_model, _ = shardformer.optimize(model, policy) self.sharded_model = self.sharded_model.cuda() @property @@ -152,10 +150,7 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], if 'max_new_tokens' not in generate_kwargs: generate_kwargs.update(max_new_tokens=self.max_output_len) - if self.sharded_model is not None: - return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) - - return self.model.generate(**input_tokens, **generate_kwargs) + return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) def prepare_batch_state(self, inputs) -> BatchInferState: """ diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py index 20a6729abc21..67ff13bb5f5e 100644 --- a/examples/inference/bench_bloom.py +++ b/examples/inference/bench_bloom.py @@ -48,7 +48,6 @@ def bench_bloom(args): # To benchmark torch original, comment out the line of optimizing model shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) - infer_engine.optimize_model() # prepare data for generation generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index b8ee8eb4f69d..d2016a4587e6 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -72,7 +72,6 @@ def run_llama_test(args): shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) - infer_engine.optimize_model() generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index f26f05abeb79..42b000e90e30 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -35,7 +35,6 @@ def run(test_config): shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, inference_only=True) infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - infer_engine.optimize_model() generate_kwargs = dict(do_sample=False) outputs = infer_engine.generate(data, **generate_kwargs) diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index b1b3b57068c1..5975e670d133 100644 --- a/tests/test_infer/test_infer_engine.py +++ 
b/tests/test_infer/test_infer_engine.py @@ -76,7 +76,7 @@ def __init__(self): @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -def test_orig_generate(): +def test_generate(): input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) model_config = LlamaConfig() @@ -108,7 +108,6 @@ def run(test_config): shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - infer_engine.optimize_model() assert infer_engine.cache_manager is not None assert infer_engine.tp_size == TP_SIZE @@ -127,11 +126,11 @@ def check_engine(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_engine_infer(): +def test_engine_tp(): spawn(check_engine, TP_SIZE) if __name__ == '__main__': test_prepare_data() - test_orig_generate() - test_engine_infer() + test_generate() + test_engine_tp() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 7dfb63e16e8e..186a490ed8d5 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -61,7 +61,6 @@ def run_llama_test(test_config): shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, inference_only=True) infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - infer_engine.optimize_model() generate_kwargs = dict(do_sample=False) outputs = infer_engine.generate(data, **generate_kwargs) From afcb08d0ffaf2318c04f5a2034bfedb7e9efd496 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 11 Sep 2023 17:16:24 +0800 Subject: [PATCH 2/3] revise inference tests --- tests/test_infer/test_bloom_infer.py | 2 - tests/test_infer/test_infer_engine.py | 82 +++++++-------------------- tests/test_infer/test_llama_infer.py | 2 - 3 files changed, 20 insertions(+), 66 deletions(-) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index 42b000e90e30..8ecabf69ecf3 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -2,9 +2,7 @@ import pytest import torch -import torch.distributed as dist from packaging import version -from transformers import AutoTokenizer, BloomForCausalLM import colossalai from colossalai.inference.tensor_parallel import TPInferEngine diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index 5975e670d133..cc3cdd2b501b 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -11,7 +11,7 @@ from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn TP_SIZE = 2 @@ -22,31 +22,25 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -def test_prepare_data(): - # dummy module used for testing - class DummyModule(nn.Module): - - def __init__(self, config): - 
super(DummyModule, self).__init__() - self.config = config - - def forward(self, x): - return x - - # dummy config used for testing - class DummyModelConfig: - - def __init__(self): - self.hidden_size = 4096 - self.num_attention_heads = 32 - self.num_hidden_layers = 8 +@parameterize('test_config', [{ + 'tp_size': TP_SIZE, +}]) +def run(test_config): + model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) + model = BloomForCausalLM(model_config) + model = model.half() + model.to(torch.cuda.current_device()) - dummy_config = DummyModelConfig() - model = DummyModule(dummy_config) - shard_config = ShardConfig(enable_tensor_parallelism=False) + # 1. check TPInferEngine init and model optimization + shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, + inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + assert infer_engine.cache_manager is not None + assert infer_engine.tp_size == TP_SIZE + assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE + + # 2. check data preparation input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970], [80540, 15473, 3331, 11970], [80540, 15473]] batch_size = len(input_ids_list) @@ -74,48 +68,14 @@ def __init__(self): # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc) # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc) - -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -def test_generate(): + # 3. check optimized model generate input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) - - model_config = LlamaConfig() - model = LlamaForCausalLM(model_config) - model = model.half() - model.to(torch.cuda.current_device()) - - shard_config = ShardConfig(enable_tensor_parallelism=False) - - # init TPInferEngine - infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - # original model generate generate_kwargs = dict(do_sample=False) infer_engine.generate(input_ids, **generate_kwargs) torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': TP_SIZE, -}]) -def run(test_config): - model_config = BloomConfig() - model = BloomForCausalLM(model_config) - model = model.half() - model.to(torch.cuda.current_device()) - - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) - infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - assert infer_engine.cache_manager is not None - assert infer_engine.tp_size == TP_SIZE - assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE - - torch.cuda.empty_cache() - - def check_engine(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -126,11 +86,9 @@ def check_engine(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_engine_tp(): +def test_engine(): spawn(check_engine, TP_SIZE) if __name__ == '__main__': - test_prepare_data() - test_generate() - test_engine_tp() + test_engine() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 186a490ed8d5..aa8874ea4cb0 100644 --- 
a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -3,9 +3,7 @@ import pytest import torch -import torch.distributed as dist from packaging import version -from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine From 57d94c92f9961c48a224be623929cc0850606674 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 11 Sep 2023 17:32:45 +0800 Subject: [PATCH 3/3] trivial: rename --- colossalai/inference/tensor_parallel/engine.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 2cb115c30770..a5a55702ade0 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -67,7 +67,7 @@ def __init__(self, self.cache_manager = None self.shard_config = shard_config - self.sharded_model = None + self.model = None # optimize the original model by sharding with ShardFormer self._optimize_model(model=model.to(device)) @@ -124,8 +124,8 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(model, inference_only=True) - self.sharded_model, _ = shardformer.optimize(model, policy) - self.sharded_model = self.sharded_model.cuda() + self.model, _ = shardformer.optimize(model, policy) + self.model = self.model.cuda() @property def supported_models(self) -> List[str]: @@ -234,7 +234,7 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch """ # for testing, always use sharded model - assert self.sharded_model is not None, "sharded model does not exist" + assert self.model is not None, "sharded model does not exist" batch_infer_state = self.prepare_batch_state(input_tokens) assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" @@ -243,14 +243,14 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch # NOTE this is not a preferable way to pass BatchInferState during inference # we might want to rewrite generate function (e.g. _generate_by_pass_infer_state) # and pass BatchInferState via model forward - model = self.sharded_model + model = self.model if isinstance(model, LlamaForCausalLM): - model = self.sharded_model.model + model = self.model.model elif isinstance(model, BloomForCausalLM): - model = self.sharded_model.transformer + model = self.model.transformer setattr(model, 'infer_state', batch_infer_state) - outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False) + outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) # NOTE In future development, we're going to let the scheduler to handle the cache, # instead of freeing space explicitly at the end of generation
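
Usage note (not part of the patches above): after this series, TPInferEngine shards the model with ShardFormer inside __init__ via _optimize_model, so the old explicit infer_engine.optimize_model() call is gone, and the sharded module is stored as engine.model (renamed from sharded_model in PATCH 3/3). The sketch below is a minimal illustration pieced together from the engine docstring and the revised tests; the constant values are placeholders, the Llama model is a small randomly initialized config as in the tests, and it assumes the distributed environment (e.g. colossalai.launch via spawn, as in check_engine) is already set up when tensor parallelism is enabled.

    # Minimal sketch of the post-patch API: no optimize_model() call needed.
    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    from colossalai.inference.tensor_parallel import TPInferEngine
    from colossalai.shardformer import ShardConfig

    # Placeholder values for illustration only (the tests define their own).
    TP_SIZE = 2
    MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 2, 16, 8

    # Small randomly initialized model, as in the revised tests.
    model = LlamaForCausalLM(LlamaConfig()).half().to(torch.cuda.current_device())

    shard_config = ShardConfig(enable_tensor_parallelism=TP_SIZE > 1, inference_only=True)

    # The model is moved to the target device and sharded by ShardFormer during
    # construction; the sharded module lives on the engine afterwards.
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
    outputs = infer_engine.generate(input_ids, do_sample=False)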