From ffd45b6abf27898ddb8db3d066536371760229a7 Mon Sep 17 00:00:00 2001 From: Cuiqing Li Date: Sat, 14 Oct 2023 12:59:32 +0800 Subject: [PATCH] Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. --- .../inference/dynamic_batching/io_struct.py | 8 +- colossalai/inference/manager.py | 120 ++++-------------- colossalai/inference/test_async.py | 33 ----- .../test_dynamic_batching/test_forward.py | 10 +- 4 files changed, 32 insertions(+), 139 deletions(-) delete mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 44ad2964a39f..2b2739f0ae90 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -102,21 +102,17 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self): """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. unfinished_req = [] - finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) - else: - finished_req.append(req) + unfinished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} - return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 453570c7ec3e..72f77406789f 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,6 +1,5 @@ import time from typing import List -import asyncio from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -9,8 +8,6 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine -from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -57,20 +54,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, input_ids): - """ - Encode and Add new input to req queue. support one sequence input for now. - """ - prompt_ids = self.tokenizer.encode(input_ids) - prompt_len = len(prompt_ids) - if prompt_len > self.engine.max_input_len: - raise ValueError( - f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" - ) - sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id) - return - def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -83,15 +66,13 @@ def abort(self, request_id): req.aborted = True return - async def loop_for_fwd(self): + def loop_for_fwd(self): """ The main loop for a dynamic batching process. """ counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list - while True: - async for item in self._step(): - yield item + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -106,26 +87,6 @@ async def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): - if tokenizer is not None: - self.tokenizer = tokenizer - else: - if "llama" in tokenizer_name.lower() and use_fast == True: - print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: - use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - - def _step(self): """ Logic for handling requests @@ -136,14 +97,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - yield from self._prefill_batch(self.running_batch) + self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -151,18 +112,17 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - yield from self._prefill_batch(new_mini_batch) + self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - else: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -198,8 +158,7 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) - + self._handle_finish_req(batch, has_new_finished_req) # delete finished reqs def _decode_batch(self, batch: Batch): @@ -210,7 +169,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) + self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -242,13 +201,11 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs=batch.filter_finished() + batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) - yield from self._output_process(finished_reqs) - def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -261,47 +218,26 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return - async def _output_process(self, finished_reqs: List[Req]): - """ - Process the output of a batch. - """ - for req in finished_reqs: - output = self.tokenizer.decode(req.output_ids) - yield output, req.request_id, req.output_metadata_list - def clean_up(self): # this logic should be implemented in the future. pass - async def generate(self,request_id,prompt_id,sampling_params): - """ - Generate the output of a request. - """ - self.add_input(request_id,prompt_id,sampling_params) - def start_dynamic_batching(args, tp_engine, waiting_req_list): - try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - except Exception: - batch_manager.clean_up() - raise - - batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) - prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) - - asyncio.run(prod_task) - - for item in batch_manager.loop_for_fwd(): - print(item) - - return batch_manager + # try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + # except Exception: + # batch_manager.clean_up() + # raise + + batch_manager.loop_for_fwd() + return diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py deleted file mode 100644 index 08720f36da22..000000000000 --- a/colossalai/inference/test_async.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio - -shared_list = [] - -async def producer(): - for i in range(5): - await asyncio.sleep(1) # 模拟异步获取数据的操作 - shared_list.append(i) - print(f"Produced {i}") - -async def consumer(): - last_index = 0 - while True: - await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 - if last_index < len(shared_list): - item = shared_list[last_index] - print(f"Consumed {item}") - yield item - last_index += 1 - -async def main(): - # 创建生产者和消费者任务 - prod_task = asyncio.create_task(producer()) - - # 等待生产者任务完成 - await prod_task - - async for data in consumer(): - print(data) - # 为了示例的目的,我们只等待一段时间,然后停止消费者 - await asyncio.sleep(5) - -asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index ca6401259831..63df491e5b52 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -42,7 +42,7 @@ def run(): waiting_list.append(req2) waiting_list.append(req3) waiting_list.append(req4) - + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() @@ -50,13 +50,7 @@ def run(): shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) - manager._set_tokenizer(tokenizer_name = model.__class__.__name__) - result_generator = manager.loop_for_fwd() - for result in result_generator: - print(result) - - + start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) def check_dynamic_forward(rank, world_size, port):