From fbd8d3b87c38316fd2695b2aeb3b97ef155fd063 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 11 Oct 2023 17:54:12 +0800 Subject: [PATCH 1/4] finish input and output logic --- .../inference/dynamic_batching/io_struct.py | 8 +- colossalai/inference/manager.py | 104 +++++++++++++----- .../test_dynamic_batching/test_forward.py | 6 +- 3 files changed, 88 insertions(+), 30 deletions(-) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 2b2739f0ae90..44ad2964a39f 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -102,17 +102,21 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self): + def filter_finished(self)->List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. unfinished_req = [] + finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) + else: + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 72f77406789f..678d7f0e022c 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -8,6 +8,8 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine +from transformers import AutoTokenizer +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -54,6 +56,20 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return + def add_input(self, request_id, sampling_params, input_ids): + """ + Encode and Add new input to req queue. support one sequence input for now. + """ + prompt_ids = self.tokenizer.encode(input_ids) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError( + f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" + ) + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(prompt_ids, sampling_params, request_id) + return + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -71,8 +87,9 @@ def loop_for_fwd(self): The main loop for a dynamic batching process. """ counter_count = 0 - while self.running_batch is not None or self.req_queue.waiting_req_list: - self._step() + #self.running_batch is not None or self.req_queue.waiting_req_list + while True: + yield from self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -87,6 +104,26 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms + def _set_tokenizer(self, tokenizer, tokenizer_name, trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + self.tokenizer = AutoTokenizer(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + + def _step(self): """ Logic for handling requests @@ -97,14 +134,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -112,17 +149,18 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -158,7 +196,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -169,7 +208,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -201,11 +240,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs=batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) + def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -218,26 +259,35 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + for req in finished_reqs: + #output = self.tokenizer.decode(req.output_ids) + output =req.output_ids + yield output, req.request_id, req.output_metadata_list + def clean_up(self): # this logic should be implemented in the future. pass def start_dynamic_batching(args, tp_engine, waiting_req_list): - # try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - # except Exception: - # batch_manager.clean_up() - # raise - - batch_manager.loop_for_fwd() - return + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + generator = batch_manager.loop_for_fwd() + return batch_manager,generator diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 63df491e5b52..7dcb60e73d9f 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -50,7 +50,11 @@ def run(): shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + manager, result_generator = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + for result in result_generator: + print(result) + + def check_dynamic_forward(rank, world_size, port): From e6cb350a217f00e72ee9e1084f079c8d39da0012 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 11 Oct 2023 18:38:00 +0800 Subject: [PATCH 2/4] add generate --- colossalai/inference/manager.py | 20 ++++++++++++------- .../test_dynamic_batching/test_forward.py | 8 +++++--- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 678d7f0e022c..23c37a03ba2e 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,5 +1,6 @@ import time from typing import List +import asyncio from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -62,6 +63,7 @@ def add_input(self, request_id, sampling_params, input_ids): """ prompt_ids = self.tokenizer.encode(input_ids) prompt_len = len(prompt_ids) + print(prompt_ids) if prompt_len > self.engine.max_input_len: raise ValueError( f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" @@ -104,7 +106,7 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms - def _set_tokenizer(self, tokenizer, tokenizer_name, trust_remote_code: bool = False, use_fast:bool = True,): + def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): if tokenizer is not None: self.tokenizer = tokenizer else: @@ -118,10 +120,10 @@ def _set_tokenizer(self, tokenizer, tokenizer_name, trust_remote_code: bool = Fa tokenizer_name = _FAST_LLAMA_TOKENIZER try: - self.tokenizer = AutoTokenizer(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) except TypeError as e: use_fast = False - self.tokenizer = AutoTokenizer(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) def _step(self): @@ -264,14 +266,19 @@ def _output_process(self, finished_reqs: List[Req]): Process the output of a batch. """ for req in finished_reqs: - #output = self.tokenizer.decode(req.output_ids) - output =req.output_ids + output = self.tokenizer.decode(req.output_ids) yield output, req.request_id, req.output_metadata_list def clean_up(self): # this logic should be implemented in the future. pass + def generate(self,request_id,prompt_id,sampling_params): + """ + Generate the output of a request. + """ + self.add_input(request_id,prompt_id,sampling_params) + return self.loop_for_fwd() def start_dynamic_batching(args, tp_engine, waiting_req_list): try: @@ -289,5 +296,4 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list): batch_manager.clean_up() raise - generator = batch_manager.loop_for_fwd() - return batch_manager,generator + return batch_manager diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 7dcb60e73d9f..1894bdb030ab 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -42,15 +42,17 @@ def run(): waiting_list.append(req2) waiting_list.append(req3) waiting_list.append(req4) - + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) - model = LlamaForCausalLM(llama_config) + model = LlamaForCausalLM.from_pretrained('/mnt/vepfs/lczyh/models/llama-7b-hf/') model = model.half() shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - manager, result_generator = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + manager._set_tokenizer(tokenizer_name = model.__class__.__name__) + result_generator = manager.loop_for_fwd() for result in result_generator: print(result) From a4d1e334f659c02ee6be76357eeaa1cda1eaa336 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 12 Oct 2023 10:55:43 +0800 Subject: [PATCH 3/4] test forward --- tests/test_infer/test_dynamic_batching/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 1894bdb030ab..764bea9c297d 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -44,7 +44,7 @@ def run(): waiting_list.append(req4) llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) - model = LlamaForCausalLM.from_pretrained('/mnt/vepfs/lczyh/models/llama-7b-hf/') + model = LlamaForCausalLM.from_pretrained(llama_config) model = model.half() shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) From 75b18e363b090e58c24bd3aa2f09e38561a5fc1d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 12 Oct 2023 15:56:38 +0800 Subject: [PATCH 4/4] 1 --- colossalai/inference/manager.py | 20 +++++++---- colossalai/inference/test_async.py | 33 +++++++++++++++++++ .../test_dynamic_batching/test_forward.py | 2 +- 3 files changed, 48 insertions(+), 7 deletions(-) create mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 23c37a03ba2e..453570c7ec3e 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -63,7 +63,6 @@ def add_input(self, request_id, sampling_params, input_ids): """ prompt_ids = self.tokenizer.encode(input_ids) prompt_len = len(prompt_ids) - print(prompt_ids) if prompt_len > self.engine.max_input_len: raise ValueError( f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" @@ -84,14 +83,15 @@ def abort(self, request_id): req.aborted = True return - def loop_for_fwd(self): + async def loop_for_fwd(self): """ The main loop for a dynamic batching process. """ counter_count = 0 #self.running_batch is not None or self.req_queue.waiting_req_list while True: - yield from self._step() + async for item in self._step(): + yield item counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -261,7 +261,7 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return - def _output_process(self, finished_reqs: List[Req]): + async def _output_process(self, finished_reqs: List[Req]): """ Process the output of a batch. """ @@ -273,12 +273,12 @@ def clean_up(self): # this logic should be implemented in the future. pass - def generate(self,request_id,prompt_id,sampling_params): + async def generate(self,request_id,prompt_id,sampling_params): """ Generate the output of a request. """ self.add_input(request_id,prompt_id,sampling_params) - return self.loop_for_fwd() + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: @@ -295,5 +295,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list): except Exception: batch_manager.clean_up() raise + + batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) + prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) + + asyncio.run(prod_task) + + for item in batch_manager.loop_for_fwd(): + print(item) return batch_manager diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py new file mode 100644 index 000000000000..08720f36da22 --- /dev/null +++ b/colossalai/inference/test_async.py @@ -0,0 +1,33 @@ +import asyncio + +shared_list = [] + +async def producer(): + for i in range(5): + await asyncio.sleep(1) # 模拟异步获取数据的操作 + shared_list.append(i) + print(f"Produced {i}") + +async def consumer(): + last_index = 0 + while True: + await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 + if last_index < len(shared_list): + item = shared_list[last_index] + print(f"Consumed {item}") + yield item + last_index += 1 + +async def main(): + # 创建生产者和消费者任务 + prod_task = asyncio.create_task(producer()) + + # 等待生产者任务完成 + await prod_task + + async for data in consumer(): + print(data) + # 为了示例的目的,我们只等待一段时间,然后停止消费者 + await asyncio.sleep(5) + +asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 764bea9c297d..ca6401259831 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -44,7 +44,7 @@ def run(): waiting_list.append(req4) llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) - model = LlamaForCausalLM.from_pretrained(llama_config) + model = LlamaForCausalLM(llama_config) model = model.half() shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)