[Inference] Dynamic Batching Inference, online and offline #4953
Merged
Commits (36):
e0757c3  [inference] Dynamic Batching for Single and Multiple GPUs (#4831)  (CjhHa1)
fced140  [inference] Async dynamic batching (#4894)  (CjhHa1)
fbf3c09  [inference]Re push async dynamic batching (#4901)  (CjhHa1)
d509e79  Revert "[inference]Re push async dynamic batching (#4901)" (#4905)  (CjhHa1)
ec004fe  Revert "[inference] Async dynamic batching (#4894)"  (isky-cd)
78cd937  Revert "[inference] Async dynamic batching (#4894)" (#4909)  (tiandiao123)
d97290a  Add Ray Distributed Environment Init Scripts  (isky-cd)
8483393  fix conflict  (isky-cd)
f589e97  support DynamicBatchManager base function  (isky-cd)
c070050  revert _set_tokenizer version  (isky-cd)
5deb95c  add driver async generate  (isky-cd)
306ef77  add async test  (isky-cd)
632f0e1  fix bugs in test_ray_dist.py  (isky-cd)
0b2fe51  add get_tokenizer.py  (isky-cd)
cd843ac  fix code style  (isky-cd)
8c9ad51  fix bugs about No module named 'pydantic' in ci test  (isky-cd)
8d0cc6b  fix bugs in ci test  (isky-cd)
acdd751  fix bugs in ci test  (isky-cd)
8a761bd  fix bugs in ci test  (isky-cd)
56f75c4  [infer]Add Ray Distributed Environment Init Scripts (#4911)  (isky-cd)
c76fd68  support dynamic batch for bloom model and is_running function  (isky-cd)
f41ccdd  fix conflict  (isky-cd)
fca12b8  Merge pull request #4933 from yuehuayingxueluo/ray_dist_init_branch  (isky-cd)
4ea9fbe  [Inference]Test for new Async engine (#4935)  (CjhHa1)
3f6af12  add assertion for config (#4947)  (CjhHa1)
4867561  [Inference] Finish dynamic batching offline test (#4948)  (CjhHa1)
285fc30  fix bugs  (CjhHa1)
d5d2c94  fix quant  (CjhHa1)
ed86584  add default  (CjhHa1)
4bffb8b  fix  (CjhHa1)
77adc2e  fix some bugs  (CjhHa1)
dcb51b4  fix some bugs  (CjhHa1)
afae53b  fix
c477266  fix bug
f99eba2  fix bugs
4c3ea40  reset param
New file (133 added lines):

import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
    """
    A class that tracks all in-flight requests; the abstraction backing the async engine.
    """

    def __init__(self) -> None:
        self._requests: asyncio.Queue[str] = asyncio.Queue()
        self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        return item in self._requests

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str):
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        self._requests.put_nowait(request_id)
        self.new_requests_event.set()  # NOTE: we may find a better way to clear this event

    def add_stop(self):
        """
        Add a StopIteration flag to stop the async generator.
        """
        self._finished_requests.put_nowait(StopIteration)
        self.new_requests_event.clear()

    def process_request_output(self, request_output: RequestOutput) -> None:
        """Process a request output from the engine."""
        self._finished_requests.put_nowait(request_output)

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._finished_requests.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result


class Async_Engine:
    """
    Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
    Background loop: run inference on requests in the waiting list (listen)
    Request tracker: manage incoming requests and store finished ones
    Generate: exposed func to add new input and return finished results
    """

    def __init__(
        self,
        router_config,
        engine_config,
        start_engine_loop: bool = True,
    ) -> None:
        self.driver = Driver(router_config=router_config, engine_config=engine_config)
        self.background_loop = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    def _step(self):
        """
        Logic for handling requests
        """
        request_outputs = self.driver.step()
        if request_outputs is not None:
            for request_output in request_outputs:
                self._request_tracker.process_request_output(request_output)
            self._request_tracker.add_stop()

    def abort_request(self, request_id: str):
        self.driver.abort(request_id)

    def _has_requests_in_progress(self):
        return self.driver.is_running()

    async def run_loop_fwd(self):
        has_requests_in_progress = self._has_requests_in_progress()
        while True:
            if not has_requests_in_progress:
                await self._request_tracker.wait_for_new_requests()
            self._step()
            await asyncio.sleep(0)

    @property
    def is_running(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        self._request_tracker.init_event()

        self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
        self.background_loop = asyncio.shield(self.background_loop_unshielded)

    async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        self.driver.add_input(request_id, prompt, sampling_params)
        self._request_tracker.add_request(request_id)

    async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        """
        The only exposed func: add a new request and return an async generator that yields results as they arrive.
        """
        try:
            if not self.is_running:
                self.start_background_loop()

            await self.add_request(request_id, prompt, sampling_params)

            async for request_output in self._request_tracker:
                yield request_output

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self.abort_request(request_id)
            raise e
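For orientation, here is a minimal consumption sketch (not part of the diff) for the Async_Engine above. The request id and the SamplingParams construction are illustrative assumptions, and constructing the engine itself needs the router/engine configs that the Ray driver in this PR expects, so that part is left to the caller.

import asyncio

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def run_one_request(engine, prompt: str) -> None:
    # `engine` is assumed to be an already constructed Async_Engine.
    sampling_params = SamplingParams()  # default sampling; the exact fields depend on the PR's SamplingParams

    # generate() lazily starts the background loop, registers the request with the
    # Ray driver, and yields RequestOutput objects as the dynamic batcher finishes them.
    async for request_output in engine.generate("req-0", prompt, sampling_params):
        print(request_output)


# asyncio.run(run_one_request(engine, "What is dynamic batching?"))  # engine built elsewhere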
New file (151 added lines):

from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num: int,
        batch_max_tokens: int,
        model: str,
        tokenizer=None,
        eos_id=None,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
        waiting_req_list: List = [],
    ):
        """
        Args:
            tp_engine : the TP inference engine that the dynamic batch manager holds, created before the manager
            max_total_token_num : max_total_token_num for the memory manager, defaults to max batch size * (max input len + max output len)
            batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
            running_max_req_size : max request size of the running batch, equal to MAX_BATCH_SIZE of the TP engine
            eos_id : the end token of a sequence
            model : the model weight dir path; the app loads config, weights and tokenizer from this dir
            log_stats : whether to log stats
            log_stats_interval : log stats interval
            running_batch : running batch
            waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
        """
        super().__init__(
            tp_engine,
            max_total_token_num,
            batch_max_tokens,
            model,
            tokenizer,
            eos_id,
            log_stats,
            log_stats_interval,
            running_batch,
            waiting_req_list,
        )

    def _step(self):
        """
        Logic for handling requests
        """
        has_new_finished = False
        if self.running_batch is None:
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                has_new_finished, outputs = self._prefill_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens = 0

        else:
            if self.has_wait_tokens < self.max_wait_tokens:
                self.stats_tool.count_output_tokens(self.running_batch)
                has_new_finished, outputs = self._decode_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens += 1

            else:
                new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
                if new_mini_batch is not None:
                    self.stats_tool.count_prompt_tokens(new_mini_batch)
                    has_new_finished, outputs = self._prefill_batch(new_mini_batch)
                    if not new_mini_batch.is_clear():
                        self._merge_batch(self.running_batch, new_mini_batch)
                        self.running_batch.merge(new_mini_batch)
                    self.has_wait_tokens = 0

                else:
                    self.stats_tool.count_output_tokens(self.running_batch)
                    has_new_finished, outputs = self._decode_batch(self.running_batch)
                    self._filter_runing_batch()
                    self.has_wait_tokens += 1

        if has_new_finished:
            return outputs
        return None

    def _prefill_batch(self, batch):
        """
        Every batch, whether it is a new batch or a mini batch, goes through prefill first.
        """
        self._init_batch(batch)

        # TODO: figure out if cache and batch id are needed
        ans = self.engine._prefill_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        # delete finished reqs
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _decode_batch(self, batch: Batch):
        """
        Decoding process
        """
        ans = self.engine._decode_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
            finished_reqs = batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
            return self._output_process(finished_reqs)
        return None

    def _output_process(self, finished_reqs: List[Req]):
        """
        Process the output of a batch.
        """
        outputs = []
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
            outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
        return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
    try:
        batch_manager = Async_DynamicBatchManager(
            tp_engine=tp_engine,
            max_total_token_num=args.max_total_token_num,
            batch_max_tokens=args.batch_max_tokens,
            eos_id=args.eos_id,
            model=args.model,
            log_stats=not args.disable_log_stats,
            log_stats_interval=args.log_stats_interval,
            waiting_req_list=waiting_req_list,
        )

    except Exception:
        raise Exception

    return batch_manager
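To show how the offline path might drive this manager, a rough sketch follows. The argument fields mirror what start_dynamic_batching() reads above, but the values, the tp_engine and waiting_req_list objects, and the req_queue.waiting_req_list attribute used in the loop condition are assumptions rather than code from this PR.

from types import SimpleNamespace


def run_offline(tp_engine, waiting_req_list):
    # tp_engine: a TPInferEngine; waiting_req_list: a list of Req objects,
    # both assumed to be prepared by the caller, as in the PR's offline test.
    args = SimpleNamespace(
        max_total_token_num=4096,    # illustrative value
        batch_max_tokens=2048,       # illustrative value
        eos_id=2,                    # illustrative value
        model="/path/to/model",      # model weight dir, per the docstring above
        disable_log_stats=False,
        log_stats_interval=10,
    )
    manager = start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=waiting_req_list)

    # Keep stepping until the running batch is drained and, assuming the request queue
    # stores its pending requests in waiting_req_list, nothing is left to schedule.
    while manager.running_batch is not None or manager.req_queue.waiting_req_list:
        finished = manager._step()
        for request_output in finished or []:
            print(request_output)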
An empty file is also added in this diff.
New file (40 added lines):

"""
Motivated by vLLM (https://github.com/vllm-project/vllm), this module resolves the tokenizer loading issue.

license: MIT, see LICENSE for more details.
"""

from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer=None,
    tokenizer_name: str = "",
    trust_remote_code: bool = False,
    use_fast: bool = True,
):
    if tokenizer is not None:
        tokenizer = tokenizer
    else:
        if "llama" in tokenizer_name.lower() and use_fast == True:
            print(
                "For some LLaMA-based models, initializing the fast tokenizer may "
                "take a long time. To eliminate the initialization time, consider "
                f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
                "tokenizer. This is done automatically in Colossalai."
            )

            tokenizer_name = _FAST_LLAMA_TOKENIZER

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
        except TypeError:
            use_fast = False
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
    return tokenizer
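A short usage sketch for get_tokenizer; the model name is only an example, my_tokenizer is a placeholder for an existing tokenizer object, and loading by name requires access to the Hugging Face Hub.

# Passing an existing tokenizer object returns it unchanged:
#   tokenizer = get_tokenizer(tokenizer=my_tokenizer)

# Loading by name; for LLaMA-style names with use_fast=True the function swaps in the
# hf-internal-testing/llama-tokenizer fast tokenizer automatically before loading.
tokenizer = get_tokenizer(tokenizer_name="huggyllama/llama-7b", use_fast=True)
print(tokenizer("Hello, dynamic batching!")["input_ids"])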