From 08938a7684806a012dfb0fd09beb75347a5fe285 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 17 Oct 2023 11:55:28 +0800 Subject: [PATCH 01/15] infer engine --- colossalai/inference/async_engine.py | 138 ++++++++++++++++++ .../test_dynamic_batching/config.yaml | 4 +- .../test_async_engine.py | 60 ++++++++ .../test_dynamic_batching/test_ray_dist.py | 10 +- 4 files changed, 207 insertions(+), 5 deletions(-) create mode 100644 colossalai/inference/async_engine.py create mode 100644 tests/test_infer/test_dynamic_batching/test_async_engine.py diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py new file mode 100644 index 000000000000..ab1107c19540 --- /dev/null +++ b/colossalai/inference/async_engine.py @@ -0,0 +1,138 @@ +import asyncio + +from colossalai.inference.dynamic_batching.ray_dist_init import Driver + +from .dynamic_batching.sampling_params import SamplingParams + + +class RequestTracker: + """ + A class for trace down all the requests, abstraction for async + """ + + def __init__(self) -> None: + self._requests: asyncio.Queue[str] = asyncio.Queue() + self._finished_requests: asyncio.Queue[str] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._requests + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def add_request(self, request_id: str): + """Add a request to be sent to the engine on the next background + loop iteration.""" + if request_id in self._requests: + raise KeyError(f"Request {request_id} already exists.") + + self._requests.put_nowait(request_id) + + self.new_requests_event.set() + + def abort_request(self, request_id: str, *, verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info(f"Aborted request {request_id}.") + + self._finished_requests.put_nowait(request_id) + return + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + +class Async_Engine: + + """ + loop: start listen + add req + remove req + generate--> return async generator + """ + + def __init__( + self, + driver: Driver = None, + start_engine_loop: bool = True, + ) -> None: + self.driver = driver + self.background_loop = None + self.start_engine_loop = start_engine_loop + self._request_tracker = None + + def _step(self): + """ + Logic for handling requests + """ + self.driver.step() + + def _has_requests_in_progress(self): + return self.driver.has_requests_in_progress() + + async def run_loop_fwd(self): + has_requests_in_progress = self._has_requests_in_progress() + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_requests() + has_requests_in_progress = await self._step() + await asyncio.sleep(0) + + @property + def is_running(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.is_running: + raise RuntimeError("Background loop is already running.") + + self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) + self.background_loop = asyncio.shield(self.background_loop_unshielded) + + async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.driver.add_input(request_id, prompt, sampling_params) + self._request_tracker.add_request(request_id) + + async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + """ + The only exposed func, adding new request and return a async 
generator that yields the existing results. + """ + try: + await self.add_request(request_id, prompt, sampling_params) + + stream = await self.driver.async_generate(request_id, prompt, sampling_params) + + async for request_output in stream: + yield request_output + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the request. + self._request_tracker.abort_request(request_id) + raise e + + +# def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): +# results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) +# text_res = results[0] # get any one of the copies +# return text_res + +# async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): +# all_outputs = [] +# for worker in self.workers: +# all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) +# all_outputs = await asyncio.gather(*all_outputs) +# text_res = all_outputs[0]# get any one of the copies +# return text_res + +# def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): +# ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + +# def abort(self,request_id: str): +# ray.get([w.abort.remote(request_id) for w in self.workers]) + +# def step(self): +# ray.get([w._step.remote() for w in self.workers]) + +# def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): +# ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml index c31ae8c5fadb..7af338b38df0 100644 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -1,5 +1,5 @@ engine_config: - model: MODEL_PATH + model: "/mnt/vepfs/lczyh/models/llama-7b-hf" tensor_parallel_size: 2 max_batch_size: 4 max_input_len: 128 @@ -12,4 +12,4 @@ router_config: eos_id: 0 disable_log_stats: False log_stats_interval: 10 - model: MODEL_PATH + model: "/mnt/vepfs/lczyh/models/llama-7b-hf" diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py new file mode 100644 index 000000000000..dd83846e7e6f --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -0,0 +1,60 @@ +import asyncio +import os +import uuid + +import pytest + +import colossalai +from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + + +def run_ray_dist(path: str): + print(f"Using yaml file {path}") + if not os.path.exists(path): + raise FileNotFoundError(f"Invalid yaml file path {path}") + config = RayInitConfig.from_yaml_path(path) + router_config = config.router_config_data + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + driver = Driver(router_config=router_config, engine_config=engine_config) + prompt = "Introduce some landmarks in Beijing" + + request_id = str(uuid.uuid4().hex) + sampling_params = SamplingParams() + print("sampling_params: 
", sampling_params) + + async def get_result(request_id, prompt, sampling_params): + return await driver.async_generate(request_id, prompt, sampling_params) + + for test_async in [True, False]: + if test_async: + print("test_async: ", test_async) + result = asyncio.run(get_result(request_id, prompt, sampling_params)) + print("result: ", result) + else: + print("test_async: ", test_async) + result = driver.generate(request_id, prompt, sampling_params) + print("result: ", result) + + +def check_ray_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_ray_dist(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_ray_dist(): + spawn(check_ray_dist, 1) + + +if __name__ == "__main__": + test_ray_dist() diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 4cf9881f41dc..dd83846e7e6f 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -2,15 +2,17 @@ import os import uuid +import pytest + +import colossalai from colossalai.inference.dynamic_batching.ray_dist_init import Driver from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -import colossalai -import pytest from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn PATH = "config.yaml" + def run_ray_dist(path: str): print(f"Using yaml file {path}") if not os.path.exists(path): @@ -25,8 +27,8 @@ def run_ray_dist(path: str): prompt = "Introduce some landmarks in Beijing" request_id = str(uuid.uuid4().hex) - sampling_params = SamplingParams() + print("sampling_params: ", sampling_params) async def get_result(request_id, prompt, sampling_params): return await driver.async_generate(request_id, prompt, sampling_params) @@ -41,6 +43,7 @@ async def get_result(request_id, prompt, sampling_params): result = driver.generate(request_id, prompt, sampling_params) print("result: ", result) + def check_ray_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_ray_dist(PATH) @@ -52,5 +55,6 @@ def check_ray_dist(rank, world_size, port): def test_ray_dist(): spawn(check_ray_dist, 1) + if __name__ == "__main__": test_ray_dist() From 4e7aa7c77a0b880953ddea3f2b8fa5399c4ed405 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 17 Oct 2023 12:09:57 +0800 Subject: [PATCH 02/15] infer engine --- colossalai/inference/async_engine.py | 9 ++++++--- .../test_dynamic_batching/test_async_engine.py | 8 ++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index ab1107c19540..cdae2ffa621f 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -54,13 +54,14 @@ class Async_Engine: def __init__( self, - driver: Driver = None, + router_config, + engine_config, start_engine_loop: bool = True, ) -> None: - self.driver = driver + self.driver = Driver(router_config=router_config, engine_config=engine_config) self.background_loop = None self.start_engine_loop = start_engine_loop - self._request_tracker = None + self._request_tracker = RequestTracker() def _step(self): """ @@ -87,6 +88,8 @@ def start_background_loop(self): if self.is_running: 
raise RuntimeError("Background loop is already running.") + self._request_tracker.init_event() + self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) self.background_loop = asyncio.shield(self.background_loop_unshielded) diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index dd83846e7e6f..b12b79072ea6 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -44,7 +44,7 @@ async def get_result(request_id, prompt, sampling_params): print("result: ", result) -def check_ray_dist(rank, world_size, port): +def check_async_engine(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_ray_dist(PATH) @@ -52,9 +52,9 @@ def check_ray_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_ray_dist(): - spawn(check_ray_dist, 1) +def test_async_engine(): + spawn(check_async_engine, 1) if __name__ == "__main__": - test_ray_dist() + test_async_engine() From 2e01e9e5f502966034e9c4c87cc3feb66e50c944 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 17 Oct 2023 13:51:52 +0800 Subject: [PATCH 03/15] test engine --- colossalai/inference/async_engine.py | 7 ++++-- .../dynamic_batching/ray_dist_init.py | 7 +++++- colossalai/inference/manager.py | 1 + .../test_async_engine.py | 25 ++++++++----------- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index cdae2ffa621f..5d4f1c8dacdf 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -24,8 +24,8 @@ def init_event(self): def add_request(self, request_id: str): """Add a request to be sent to the engine on the next background loop iteration.""" - if request_id in self._requests: - raise KeyError(f"Request {request_id} already exists.") + # if request_id in self._requests: + # raise KeyError(f"Request {request_id} already exists.") self._requests.put_nowait(request_id) @@ -102,6 +102,9 @@ async def generate(self, request_id: str, prompt: str, sampling_params: Sampling The only exposed func, adding new request and return a async generator that yields the existing results. 
""" try: + if not self.is_running: + self.start_background_loop() + await self.add_request(request_id, prompt, sampling_params) stream = await self.driver.async_generate(request_id, prompt, sampling_params) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index a40a00e2666c..01dcf38eeb8e 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -152,7 +152,12 @@ async def async_generate(self, request_id: str, prompt: str, sampling_params: Sa return text_res def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + ray.get( + [ + w.add_input.remote(request_id=request_id, prompt=prompt, sampling_params=sampling_params) + for w in self.workers + ] + ) def abort(self, request_id: str): ray.get([w.abort.remote(request_id) for w in self.workers]) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 30717a915e3b..f75a1c012a26 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -63,6 +63,7 @@ def add_input(self, request_id, sampling_params, prompts): """ Encode and Add new input to req queue. support one sequence input for now. """ + print("promptssssss", prompts) prompt_ids = self.tokenizer.encode(prompts) prompt_len = len(prompt_ids) if prompt_len > self.engine.max_input_len: diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index b12b79072ea6..6c3f498b351b 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -5,7 +5,7 @@ import pytest import colossalai -from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.async_engine import Async_Engine from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from colossalai.inference.dynamic_batching.sampling_params import SamplingParams from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn @@ -13,7 +13,7 @@ PATH = "config.yaml" -def run_ray_dist(path: str): +def run_async_engine(path: str): print(f"Using yaml file {path}") if not os.path.exists(path): raise FileNotFoundError(f"Invalid yaml file path {path}") @@ -23,30 +23,25 @@ def run_ray_dist(path: str): model = engine_config.model if model is None or not os.path.exists(model): return - driver = Driver(router_config=router_config, engine_config=engine_config) + engine = Async_Engine(router_config=router_config, engine_config=engine_config) + prompt = "Introduce some landmarks in Beijing" request_id = str(uuid.uuid4().hex) sampling_params = SamplingParams() - print("sampling_params: ", sampling_params) async def get_result(request_id, prompt, sampling_params): - return await driver.async_generate(request_id, prompt, sampling_params) - - for test_async in [True, False]: - if test_async: - print("test_async: ", test_async) - result = asyncio.run(get_result(request_id, prompt, sampling_params)) - print("result: ", result) - else: - print("test_async: ", test_async) - result = driver.generate(request_id, prompt, sampling_params) + results = engine.generate(request_id, prompt, sampling_params) + async for result in results: print("result: ", result) + return result + + asyncio.run(get_result(request_id, prompt, 
sampling_params)) def check_async_engine(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_ray_dist(PATH) + run_async_engine(PATH) @pytest.mark.dist From 8773b76269d5168cd3a3a62941879fff6db3eb3e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 17 Oct 2023 15:26:10 +0800 Subject: [PATCH 04/15] test engine --- colossalai/inference/async_engine.py | 6 +++--- colossalai/inference/dynamic_batching/ray_dist_init.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index 5d4f1c8dacdf..d7404491ae5b 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -70,14 +70,14 @@ def _step(self): self.driver.step() def _has_requests_in_progress(self): - return self.driver.has_requests_in_progress() + return self.driver.is_running() async def run_loop_fwd(self): has_requests_in_progress = self._has_requests_in_progress() while True: if not has_requests_in_progress: - await self._request_tracker.wait_for_requests() - has_requests_in_progress = await self._step() + await self._request_tracker.wait_for_new_requests() + self._step() await asyncio.sleep(0) @property diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index b75bf12f7e06..09ec09c42704 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -100,7 +100,7 @@ def step(self): def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - + def is_running(self): return self.start_dynamic_batching.is_running() @@ -166,11 +166,11 @@ def abort(self, request_id: str): ray.get([w.abort.remote(request_id) for w in self.workers]) def step(self): - ray.get([w._step.remote() for w in self.workers]) + ray.get([w.step.remote() for w in self.workers]) def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) - + def is_running(self): results = ray.get([w.is_running.remote() for w in self.workers]) return any(results) From eec84f82ba84cb59dec52c689109c3936a774cc3 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 17 Oct 2023 17:33:48 +0800 Subject: [PATCH 05/15] new manager --- colossalai/inference/async_engine.py | 16 +- .../dynamic_batching/async_manager.py | 288 ++++++++++++++++++ .../inference/dynamic_batching/io_struct.py | 23 ++ 3 files changed, 323 insertions(+), 4 deletions(-) create mode 100644 colossalai/inference/dynamic_batching/async_manager.py diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index d7404491ae5b..6b1842f3bcf3 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -2,6 +2,7 @@ from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from .dynamic_batching.io_struct import RequestOutput from .dynamic_batching.sampling_params import SamplingParams @@ -12,7 +13,7 @@ class RequestTracker: def __init__(self) -> None: self._requests: asyncio.Queue[str] = asyncio.Queue() - self._finished_requests: asyncio.Queue[str] = asyncio.Queue() + self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue() 
self.new_requests_event = None def __contains__(self, item): @@ -28,7 +29,6 @@ def add_request(self, request_id: str): # raise KeyError(f"Request {request_id} already exists.") self._requests.put_nowait(request_id) - self.new_requests_event.set() def abort_request(self, request_id: str, *, verbose: bool = False) -> None: @@ -36,9 +36,12 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None: if verbose: logger.info(f"Aborted request {request_id}.") - self._finished_requests.put_nowait(request_id) return + def process_request_output(self, request_output: RequestOutput) -> None: + """Process a request output from the engine.""" + self._finished_requests.put_nowait(request_output) + async def wait_for_new_requests(self): await self.new_requests_event.wait() @@ -67,7 +70,10 @@ def _step(self): """ Logic for handling requests """ - self.driver.step() + request_outputs = self.driver.step() + if request_outputs is not None: + for request_output in request_outputs: + self._request_tracker.process_request_output(request_output) def _has_requests_in_progress(self): return self.driver.is_running() @@ -106,6 +112,8 @@ async def generate(self, request_id: str, prompt: str, sampling_params: Sampling self.start_background_loop() await self.add_request(request_id, prompt, sampling_params) + for request_output in request_outputs: + self._request_tracker.process_request_output(request_output, verbose=self.log_requests) stream = await self.driver.async_generate(request_id, prompt, sampling_params) diff --git a/colossalai/inference/dynamic_batching/async_manager.py b/colossalai/inference/dynamic_batching/async_manager.py new file mode 100644 index 000000000000..5101f046012c --- /dev/null +++ b/colossalai/inference/dynamic_batching/async_manager.py @@ -0,0 +1,288 @@ +import time +from typing import List + +from .dynamic_batching.get_tokenizer import get_tokenizer +from .dynamic_batching.infer_batch import InferBatch +from .dynamic_batching.io_struct import Batch, Req, RequestOutput +from .dynamic_batching.req_queue import ReqQueue +from .dynamic_batching.sampling_params import SamplingParams +from .dynamic_batching.stats import Stats +from .tensor_parallel import TPInferEngine + + +class Async_DynamicBatchManager: + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + eos_id, + model, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + self.engine = tp_engine + self.max_total_token_num = max_total_token_num + running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) + # all the inputs should be put into 
req_queue: waiting req list + + self.running_batch: Batch = running_batch + self.eos_id = eos_id + self.has_wait_tokens = 0 + self.max_wait_tokens = 10 + self.model = model + + self.stats_tool = Stats(log_stats, log_stats_interval) + self.mem_usage_interval = log_stats_interval * 2 + self.tokenizer = get_tokenizer(tokenizer_name=self.model) + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): + """ + Add new request to req queue, during initialization all requests are held in waiting list. + """ + req = Req(request_id, prompt_ids, sampling_params, prompts) + self.req_queue.append(req) + return + + def add_input(self, request_id, sampling_params, prompts): + """ + Encode and Add new input to req queue. support one sequence input for now. + """ + print("promptssssss", prompts) + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(prompt_ids, sampling_params, request_id, prompts) + return + + def abort(self, request_id): + if self.running_batch is not None: + for req in self.running_batch.reqs: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + for req in self.req_queue.waiting_req_list: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + return + + def loop_for_fwd(self): + """ + The main loop for a dynamic batching process. + """ + counter_count = 0 + # self.running_batch is not None or self.req_queue.waiting_req_list + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() + counter_count += 1 + if self.running_batch is not None: + if counter_count % self.mem_usage_interval == 0: + print( + "current batch size:", + len(self.running_batch.reqs), + "token used ratio:", + self.running_batch.calcu_used_tokens() / self.max_total_token_num, + ) + self.stats_tool.print_stats() + + if self.running_batch is None: + time.sleep(0.1) # 10ms + + def _step(self): + """ + Logic for handling requests + """ + + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + has_new_finished, outputs = self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + has_new_finished, outputs = self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + if has_new_finished: + return outputs + + return None + + def _init_batch(self, batch: Batch, dtype="fp16"): + reqs = [r.to_rpc_obj() for 
r in batch.reqs] + batch_id = batch.batch_id + + import torch + + if dtype == "fp16": + dtype = torch.float16 + else: + assert False, "error dtype" + + batch_data = InferBatch.init_batch( + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.engine.cache_manager, + self.engine.model.config.vocab_size, + self.engine.max_input_len + self.engine.max_output_len, + ) + self.engine.cache[batch_id] = batch_data + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + + def _filter_batch(self, batch: Batch): + batch_id = batch.batch_id + req_id_list = [r.request_id for r in batch.reqs] + batch = self.engine.cache.pop(batch_id) + filter_batch = batch.filter(req_id_list) + del batch + self.engine.cache[batch_id] = filter_batch + + def _merge_batch(self, batch1, batch2): + """ + Merge new mini batch into running batch. + """ + batch1 = self.engine.cache.pop(batch1.batch_id) + batch2 = self.engine.cache.pop(batch2.batch_id) + + m_batch = InferBatch.merge(batch1, batch2) + self.engine.cache[batch1.batch_id] = m_batch + del batch1 + del batch2 + + def _remove_batch(self, batch): + """ + Remove finished batch. + """ + batch = self.engine.cache.pop(batch.batch_id) + batch.free_self() + del batch + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + return self._output_process(finished_reqs) + return None + + def _filter_runing_batch(self): + if self.running_batch is not None and self.running_batch.is_clear(): + self.running_batch = None + + def _add_token_id_to_req(self, batch: Batch, req_ans): + for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): + req = batch.id_to_reqs[req_id] + req.output_ids.append(new_token_id) + req.output_metadata_list.append(new_gen_metadata) + return + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + outputs = [] + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + outputs.append(RequestOutput(req.request_id, req.prompt_ids, req.prompts, output)) + return outputs + + def clean_up(self): + # this logic should be implemented in the future. + pass + + def generate(self, prompts, sampling_params, request_id): + """ + Generate the output of a request. 
+ """ + self.add_input(request_id, sampling_params, prompts) + return self.loop_for_fwd() + + def is_running(self): + return self.running_batch is not None or self.req_queue.waiting_req_list + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + return batch_manager diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 9faaad6f111e..1475e725e773 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -156,3 +156,26 @@ def __init__(self): class AbortReq: def __init__(self, req_id): self.req_id = req_id + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + outputs: The output sequences of the request. + """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + outputs, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs From 32520eed57e86da6ad604c0cc3789141bfe817ae Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Tue, 17 Oct 2023 17:46:48 +0800 Subject: [PATCH 06/15] change step --- .../dynamic_batching/ray_dist_init.py | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 09ec09c42704..07d3ee5dc455 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -16,6 +16,7 @@ from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import free_port +from colossalai.inference.dynamic_batching.io_struct import RequestOutput ray_serve_logger = logging.getLogger("ray.serve") @@ -76,18 +77,12 @@ def setup(self, world_size, rank, port): return True - def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str: - ray_serve_logger.info(f"text: {prompt}") + # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]: + # ray_serve_logger.info(f"text: {prompt}") - results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) - - final_output = None - for request_output in results_generator: - final_output = request_output - - assert final_output is not None - ray_serve_logger.info(f"Generated text: {final_output}") - return final_output + # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + + # return final_outputs def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) @@ -95,12 +90,13 @@ def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParam def abort(self, request_id: str): 
self.start_dynamic_batching.abort(request_id) - def step(self): - self.start_dynamic_batching._step() + def step(self) -> List[RequestOutput]: + return self.start_dynamic_batching._step() + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - + def is_running(self): return self.start_dynamic_batching.is_running() @@ -141,36 +137,39 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas _ = ray.get(init_rets) # set batch wait delay in seconds and maximum number of sequences in a batch - def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) - text_res = results[0] # get any one of the copies - return text_res - - async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - all_outputs = [] - for worker in self.workers: - all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) - all_outputs = await asyncio.gather(*all_outputs) - text_res = all_outputs[0] # get any one of the copies - return text_res + # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + # results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) + # text_res = results[0] # get any one of the copies + # return text_res + + # async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + + # async def async_get(worker, request_id, prompt, sampling_params): + # out = await worker.generate_stream.remote(request_id, prompt, sampling_params) + # print("async_get out: ", out) + # return out + + # all_outputs = [] + # for worker in self.workers: + # all_outputs.append() + # all_outputs = await asyncio.gather(*all_outputs) + # text_res = all_outputs[0] # get any one of the copies + # return text_res def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - ray.get( - [ - w.add_input.remote(request_id=request_id, prompt=prompt, sampling_params=sampling_params) - for w in self.workers - ] - ) + ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) def abort(self, request_id: str): ray.get([w.abort.remote(request_id) for w in self.workers]) def step(self): - ray.get([w.step.remote() for w in self.workers]) + results = ray.get([w._step.remote() for w in self.workers]) + outputs = results[0] # get any one of the copies + return outputs def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) - + def is_running(self): results = ray.get([w.is_running.remote() for w in self.workers]) return any(results) From 75afb9bae54ddd8dbc1be443c3d5f174deef6d26 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 17 Oct 2023 17:52:25 +0800 Subject: [PATCH 07/15] add --- colossalai/inference/async_engine.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index 6b1842f3bcf3..467d976831e2 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -45,6 +45,13 @@ def process_request_output(self, request_output: RequestOutput) -> None: async def 
wait_for_new_requests(self): await self.new_requests_event.wait() + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._finished_requests.get() + return result + class Async_Engine: @@ -112,12 +119,8 @@ async def generate(self, request_id: str, prompt: str, sampling_params: Sampling self.start_background_loop() await self.add_request(request_id, prompt, sampling_params) - for request_output in request_outputs: - self._request_tracker.process_request_output(request_output, verbose=self.log_requests) - - stream = await self.driver.async_generate(request_id, prompt, sampling_params) - async for request_output in stream: + async for request_output in self._request_tracker: yield request_output except (Exception, asyncio.CancelledError) as e: From 1b5d4f1213dcd5a971a8b74448ab78a2086e8629 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 17 Oct 2023 18:23:22 +0800 Subject: [PATCH 08/15] test --- colossalai/inference/async_engine.py | 2 ++ .../{dynamic_batching => }/async_manager.py | 4 ++-- .../inference/dynamic_batching/io_struct.py | 8 +++++++ .../dynamic_batching/ray_dist_init.py | 22 +++++++++---------- .../test_async_engine.py | 18 +++++++++------ 5 files changed, 33 insertions(+), 21 deletions(-) rename colossalai/inference/{dynamic_batching => }/async_manager.py (98%) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index 467d976831e2..29165e2b2619 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -50,6 +50,7 @@ def __aiter__(self): async def __anext__(self) -> RequestOutput: result = await self._finished_requests.get() + print("result of ", result) return result @@ -80,6 +81,7 @@ def _step(self): request_outputs = self.driver.step() if request_outputs is not None: for request_output in request_outputs: + print(request_output) self._request_tracker.process_request_output(request_output) def _has_requests_in_progress(self): diff --git a/colossalai/inference/dynamic_batching/async_manager.py b/colossalai/inference/async_manager.py similarity index 98% rename from colossalai/inference/dynamic_batching/async_manager.py rename to colossalai/inference/async_manager.py index 5101f046012c..e460fd860744 100644 --- a/colossalai/inference/dynamic_batching/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -59,7 +59,7 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, prompts): + def add_input(self, request_id, prompts, sampling_params): """ Encode and Add new input to req queue. support one sequence input for now. 
""" @@ -270,7 +270,7 @@ def is_running(self): def start_dynamic_batching(args, tp_engine, waiting_req_list): try: - batch_manager = DynamicBatchManager( + batch_manager = Async_DynamicBatchManager( tp_engine=tp_engine, max_total_token_num=args.max_total_token_num, batch_max_tokens=args.batch_max_tokens, diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 1475e725e773..0fdb1141f038 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -179,3 +179,11 @@ def __init__( self.prompt = prompt self.prompt_token_ids = prompt_token_ids self.outputs = outputs + + def __repr__(self) -> str: + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"outputs={self.outputs}, " + ) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 07d3ee5dc455..3d4171f0a917 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -1,4 +1,3 @@ -import asyncio import logging import os from typing import List @@ -9,14 +8,14 @@ from transformers import AutoModelForCausalLM import colossalai +from colossalai.inference.async_manager import start_dynamic_batching from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.io_struct import RequestOutput from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import free_port -from colossalai.inference.dynamic_batching.io_struct import RequestOutput ray_serve_logger = logging.getLogger("ray.serve") @@ -81,22 +80,21 @@ def setup(self, world_size, rank, port): # ray_serve_logger.info(f"text: {prompt}") # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) - + # return final_outputs def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) + self.start_dynamic_batching.add_input(request_id, prompt, sampling_params) def abort(self, request_id: str): self.start_dynamic_batching.abort(request_id) def step(self) -> List[RequestOutput]: return self.start_dynamic_batching._step() - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - + def is_running(self): return self.start_dynamic_batching.is_running() @@ -143,12 +141,12 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas # return text_res # async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - - # async def async_get(worker, request_id, prompt, sampling_params): + + # async def async_get(worker, request_id, prompt, sampling_params): # out = await worker.generate_stream.remote(request_id, prompt, sampling_params) # print("async_get out: ", out) # return out - + # all_outputs = [] # for worker in self.workers: # all_outputs.append() 
@@ -157,7 +155,7 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas # return text_res def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) def abort(self, request_id: str): ray.get([w.abort.remote(request_id) for w in self.workers]) @@ -169,7 +167,7 @@ def step(self): def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) - + def is_running(self): results = ray.get([w.is_running.remote() for w in self.workers]) return any(results) diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index 6c3f498b351b..62535d767958 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -27,16 +27,20 @@ def run_async_engine(path: str): prompt = "Introduce some landmarks in Beijing" - request_id = str(uuid.uuid4().hex) sampling_params = SamplingParams() + asyncio.run(asy_for_loop_test(engine, prompt, sampling_params)) + + +async def get_result(engine, prompt, sampling_params): + request_id = str(uuid.uuid4().hex) + results = engine.generate(request_id, prompt, sampling_params) + async for result in results: + print("result: ", result) - async def get_result(request_id, prompt, sampling_params): - results = engine.generate(request_id, prompt, sampling_params) - async for result in results: - print("result: ", result) - return result - asyncio.run(get_result(request_id, prompt, sampling_params)) +async def asy_for_loop_test(engine, prompt, sampling_params): + for i in range(1): + await get_result(engine, prompt, sampling_params) def check_async_engine(rank, world_size, port): From 09a56970cd8d732843bfe7d9960e586b64dd4d80 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 18 Oct 2023 10:57:41 +0800 Subject: [PATCH 09/15] fix --- colossalai/inference/async_engine.py | 2 ++ colossalai/inference/dynamic_batching/ray_dist_init.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index 29165e2b2619..1c5f3f62877d 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -79,6 +79,7 @@ def _step(self): Logic for handling requests """ request_outputs = self.driver.step() + print(request_outputs) if request_outputs is not None: for request_output in request_outputs: print(request_output) @@ -107,6 +108,7 @@ def start_background_loop(self): self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) self.background_loop = asyncio.shield(self.background_loop_unshielded) + print("start successfully") async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): self.driver.add_input(request_id, prompt, sampling_params) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 3d4171f0a917..fc150c2d53d4 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -161,7 +161,7 @@ def abort(self, request_id: str): 
ray.get([w.abort.remote(request_id) for w in self.workers]) def step(self): - results = ray.get([w._step.remote() for w in self.workers]) + results = ray.get([w.step.remote() for w in self.workers]) outputs = results[0] # get any one of the copies return outputs From 4fc9ea67e9e887a6af896ca500c679c62fb9a2eb Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 18 Oct 2023 12:14:05 +0800 Subject: [PATCH 10/15] fix --- colossalai/inference/async_engine.py | 37 ++++--------------- colossalai/inference/async_manager.py | 23 ------------ .../test_async_engine.py | 13 ++++--- 3 files changed, 14 insertions(+), 59 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index 1c5f3f62877d..b59589568f3e 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -38,6 +38,9 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None: return + def add_stop(self): + self._finished_requests.put_nowait(StopIteration) + def process_request_output(self, request_output: RequestOutput) -> None: """Process a request output from the engine.""" self._finished_requests.put_nowait(request_output) @@ -50,7 +53,9 @@ def __aiter__(self): async def __anext__(self) -> RequestOutput: result = await self._finished_requests.get() - print("result of ", result) + # print("result of ", result) + if result is StopIteration: + raise StopAsyncIteration return result @@ -79,11 +84,10 @@ def _step(self): Logic for handling requests """ request_outputs = self.driver.step() - print(request_outputs) if request_outputs is not None: for request_output in request_outputs: - print(request_output) self._request_tracker.process_request_output(request_output) + self._request_tracker.add_stop() def _has_requests_in_progress(self): return self.driver.is_running() @@ -108,7 +112,6 @@ def start_background_loop(self): self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) self.background_loop = asyncio.shield(self.background_loop_unshielded) - print("start successfully") async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): self.driver.add_input(request_id, prompt, sampling_params) @@ -131,29 +134,3 @@ async def generate(self, request_id: str, prompt: str, sampling_params: Sampling # If there is an exception or coroutine is cancelled, abort the request. 
self._request_tracker.abort_request(request_id) raise e - - -# def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): -# results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) -# text_res = results[0] # get any one of the copies -# return text_res - -# async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): -# all_outputs = [] -# for worker in self.workers: -# all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) -# all_outputs = await asyncio.gather(*all_outputs) -# text_res = all_outputs[0]# get any one of the copies -# return text_res - -# def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): -# ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) - -# def abort(self,request_id: str): -# ray.get([w.abort.remote(request_id) for w in self.workers]) - -# def step(self): -# ray.get([w._step.remote() for w in self.workers]) - -# def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): -# ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index e460fd860744..95bfe98e4fde 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -1,4 +1,3 @@ -import time from typing import List from .dynamic_batching.get_tokenizer import get_tokenizer @@ -84,28 +83,6 @@ def abort(self, request_id): req.aborted = True return - def loop_for_fwd(self): - """ - The main loop for a dynamic batching process. - """ - counter_count = 0 - # self.running_batch is not None or self.req_queue.waiting_req_list - while self.running_batch is not None or self.req_queue.waiting_req_list: - self._step() - counter_count += 1 - if self.running_batch is not None: - if counter_count % self.mem_usage_interval == 0: - print( - "current batch size:", - len(self.running_batch.reqs), - "token used ratio:", - self.running_batch.calcu_used_tokens() / self.max_total_token_num, - ) - self.stats_tool.print_stats() - - if self.running_batch is None: - time.sleep(0.1) # 10ms - def _step(self): """ Logic for handling requests diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index 62535d767958..fd3880b04957 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -18,17 +18,14 @@ def run_async_engine(path: str): if not os.path.exists(path): raise FileNotFoundError(f"Invalid yaml file path {path}") config = RayInitConfig.from_yaml_path(path) - router_config = config.router_config_data engine_config = config.engine_config_data model = engine_config.model if model is None or not os.path.exists(model): return - engine = Async_Engine(router_config=router_config, engine_config=engine_config) prompt = "Introduce some landmarks in Beijing" - sampling_params = SamplingParams() - asyncio.run(asy_for_loop_test(engine, prompt, sampling_params)) + asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) async def get_result(engine, prompt, sampling_params): @@ -38,8 +35,12 @@ async def get_result(engine, prompt, sampling_params): print("result: ", result) -async def asy_for_loop_test(engine, prompt, sampling_params): - for i in range(1): +async def 
asy_for_loop_test(config, prompt, sampling_params): + router_config = config.router_config_data + engine_config = config.engine_config_data + engine = Async_Engine(router_config=router_config, engine_config=engine_config) + for i in range(10): + print("in for loop", i) await get_result(engine, prompt, sampling_params) From 6e4ded03ead2a6f0331cbd7b1ee6e3f44ced75c1 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 18 Oct 2023 12:22:25 +0800 Subject: [PATCH 11/15] finish test --- colossalai/inference/async_engine.py | 1 - colossalai/inference/async_manager.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index b59589568f3e..98af9df6cc8d 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -35,7 +35,6 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None: """Abort a request during next background loop iteration.""" if verbose: logger.info(f"Aborted request {request_id}.") - return def add_stop(self): diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index 95bfe98e4fde..d43808714edf 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -62,7 +62,7 @@ def add_input(self, request_id, prompts, sampling_params): """ Encode and Add new input to req queue. support one sequence input for now. """ - print("promptssssss", prompts) + print("prompt", prompts) prompt_ids = self.tokenizer.encode(prompts) prompt_len = len(prompt_ids) if prompt_len > self.engine.max_input_len: @@ -81,7 +81,6 @@ def abort(self, request_id): if req.request_id == request_id: req.has_generate_finished = True req.aborted = True - return def _step(self): """ @@ -227,7 +226,7 @@ def _output_process(self, finished_reqs: List[Req]): outputs = [] for req in finished_reqs: output = self.tokenizer.decode(req.output_ids) - outputs.append(RequestOutput(req.request_id, req.prompt_ids, req.prompts, output)) + outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) return outputs def clean_up(self): From 59d18f80f3b436d61bc185503c262684771af878 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 18 Oct 2023 13:45:43 +0800 Subject: [PATCH 12/15] finish test --- colossalai/inference/async_manager.py | 35 ++++++++++++++------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index d43808714edf..ec3047c81b9a 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -86,7 +86,7 @@ def _step(self): """ Logic for handling requests """ - + has_new_finished = False if self.running_batch is None: new_batch = self.req_queue.generate_new_batch(self.running_batch) if new_batch is not None: @@ -96,30 +96,31 @@ def _step(self): self._filter_runing_batch() self.has_wait_tokens = 0 - if self.has_wait_tokens < self.max_wait_tokens: - self.stats_tool.count_output_tokens(self.running_batch) - has_new_finished, outputs = self._decode_batch(self.running_batch) - self._filter_runing_batch() - self.has_wait_tokens += 1 else: - new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) - if new_mini_batch is not None: - self.stats_tool.count_prompt_tokens(new_mini_batch) - has_new_finished, outputs = self._prefill_batch(new_mini_batch) - if not new_mini_batch.is_clear(): - self._merge_batch(self.running_batch, new_mini_batch) - 
self.running_batch.merge(new_mini_batch) - self.has_wait_tokens = 0 - - else: + if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) has_new_finished, outputs = self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + has_new_finished, outputs = self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + if has_new_finished: return outputs - return None def _init_batch(self, batch: Batch, dtype="fp16"): From 9f0a7bc6a84cf5d0294b7a6157e93d6889574630 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 18 Oct 2023 15:50:37 +0800 Subject: [PATCH 13/15] finish test --- colossalai/inference/async_manager.py | 141 ++---------------- .../dynamic_batching/get_tokenizer.py | 7 +- colossalai/inference/manager.py | 8 +- .../test_dynamic_batching/config.yaml | 4 +- .../test_async_engine.py | 4 +- 5 files changed, 27 insertions(+), 137 deletions(-) diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index ec3047c81b9a..78d11b1caa44 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -1,15 +1,11 @@ from typing import List -from .dynamic_batching.get_tokenizer import get_tokenizer -from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req, RequestOutput -from .dynamic_batching.req_queue import ReqQueue -from .dynamic_batching.sampling_params import SamplingParams -from .dynamic_batching.stats import Stats +from .manager import DynamicBatchManager from .tensor_parallel import TPInferEngine -class Async_DynamicBatchManager: +class Async_DynamicBatchManager(DynamicBatchManager): def __init__( self, tp_engine: TPInferEngine, @@ -34,53 +30,17 @@ def __init__( running_batch : running batch waiting_req_list : list of waiting requests, initialized before dynamic batch manager """ - self.engine = tp_engine - self.max_total_token_num = max_total_token_num - running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 - self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) - # all the inputs should be put into req_queue: waiting req list - - self.running_batch: Batch = running_batch - self.eos_id = eos_id - self.has_wait_tokens = 0 - self.max_wait_tokens = 10 - self.model = model - - self.stats_tool = Stats(log_stats, log_stats_interval) - self.mem_usage_interval = log_stats_interval * 2 - self.tokenizer = get_tokenizer(tokenizer_name=self.model) - - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): - """ - Add new request to req queue, during initialization all requests are held in waiting list. - """ - req = Req(request_id, prompt_ids, sampling_params, prompts) - self.req_queue.append(req) - return - - def add_input(self, request_id, prompts, sampling_params): - """ - Encode and Add new input to req queue. support one sequence input for now. 
- """ - print("prompt", prompts) - prompt_ids = self.tokenizer.encode(prompts) - prompt_len = len(prompt_ids) - if prompt_len > self.engine.max_input_len: - raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") - sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id, prompts) - return - - def abort(self, request_id): - if self.running_batch is not None: - for req in self.running_batch.reqs: - if req.request_id == request_id: - req.has_generate_finished = True - req.aborted = True - for req in self.req_queue.waiting_req_list: - if req.request_id == request_id: - req.has_generate_finished = True - req.aborted = True + super().__init__( + tp_engine, + max_total_token_num, + batch_max_tokens, + eos_id, + model, + log_stats, + log_stats_interval, + running_batch, + waiting_req_list, + ) def _step(self): """ @@ -123,28 +83,6 @@ def _step(self): return outputs return None - def _init_batch(self, batch: Batch, dtype="fp16"): - reqs = [r.to_rpc_obj() for r in batch.reqs] - batch_id = batch.batch_id - - import torch - - if dtype == "fp16": - dtype = torch.float16 - else: - assert False, "error dtype" - - batch_data = InferBatch.init_batch( - batch_id, - reqs, - dtype, - torch.cuda.current_device(), - self.engine.cache_manager, - self.engine.model.config.vocab_size, - self.engine.max_input_len + self.engine.max_output_len, - ) - self.engine.cache[batch_id] = batch_data - def _prefill_batch(self, batch): """ For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. @@ -171,34 +109,6 @@ def _decode_batch(self, batch: Batch): outputs = self._handle_finish_req(batch, has_new_finished_req) return has_new_finished_req, outputs - def _filter_batch(self, batch: Batch): - batch_id = batch.batch_id - req_id_list = [r.request_id for r in batch.reqs] - batch = self.engine.cache.pop(batch_id) - filter_batch = batch.filter(req_id_list) - del batch - self.engine.cache[batch_id] = filter_batch - - def _merge_batch(self, batch1, batch2): - """ - Merge new mini batch into running batch. - """ - batch1 = self.engine.cache.pop(batch1.batch_id) - batch2 = self.engine.cache.pop(batch2.batch_id) - - m_batch = InferBatch.merge(batch1, batch2) - self.engine.cache[batch1.batch_id] = m_batch - del batch1 - del batch2 - - def _remove_batch(self, batch): - """ - Remove finished batch. - """ - batch = self.engine.cache.pop(batch.batch_id) - batch.free_self() - del batch - def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: finished_reqs = batch.filter_finished() @@ -209,17 +119,6 @@ def _handle_finish_req(self, batch: Batch, has_new_finished_req): return self._output_process(finished_reqs) return None - def _filter_runing_batch(self): - if self.running_batch is not None and self.running_batch.is_clear(): - self.running_batch = None - - def _add_token_id_to_req(self, batch: Batch, req_ans): - for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): - req = batch.id_to_reqs[req_id] - req.output_ids.append(new_token_id) - req.output_metadata_list.append(new_gen_metadata) - return - def _output_process(self, finished_reqs: List[Req]): """ Process the output of a batch. @@ -230,20 +129,6 @@ def _output_process(self, finished_reqs: List[Req]): outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) return outputs - def clean_up(self): - # this logic should be implemented in the future. 
- pass - - def generate(self, prompts, sampling_params, request_id): - """ - Generate the output of a request. - """ - self.add_input(request_id, sampling_params, prompts) - return self.loop_for_fwd() - - def is_running(self): - return self.running_batch is not None or self.req_queue.waiting_req_list - def start_dynamic_batching(args, tp_engine, waiting_req_list): try: diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py index af1f26848b3a..0b314f8f886a 100644 --- a/colossalai/inference/dynamic_batching/get_tokenizer.py +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -1,6 +1,11 @@ +""" +Motivated by VllM, This module is trying to resolve the tokenizer issue. +license: MIT, see LICENSE for more details. +""" + from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" def get_tokenizer( diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 4c90727adba7..7bced561a91d 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -59,11 +59,10 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, prompts): + def add_input(self, request_id, prompts, sampling_params): """ Encode and Add new input to req queue. support one sequence input for now. """ - print("promptssssss", prompts) prompt_ids = self.tokenizer.encode(prompts) prompt_len = len(prompt_ids) if prompt_len > self.engine.max_input_len: @@ -258,9 +257,10 @@ def generate(self, prompts, sampling_params, request_id): """ self.add_input(request_id, sampling_params, prompts) return self.loop_for_fwd() - + def is_running(self): - return self.running_batch is not None or self.req_queue.waiting_req_list + return self.running_batch is not None or self.req_queue.waiting_req_list + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml index 7af338b38df0..c31ae8c5fadb 100644 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -1,5 +1,5 @@ engine_config: - model: "/mnt/vepfs/lczyh/models/llama-7b-hf" + model: MODEL_PATH tensor_parallel_size: 2 max_batch_size: 4 max_input_len: 128 @@ -12,4 +12,4 @@ router_config: eos_id: 0 disable_log_stats: False log_stats_interval: 10 - model: "/mnt/vepfs/lczyh/models/llama-7b-hf" + model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index fd3880b04957..f2f1212b5d07 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -14,9 +14,9 @@ def run_async_engine(path: str): - print(f"Using yaml file {path}") if not os.path.exists(path): raise FileNotFoundError(f"Invalid yaml file path {path}") + config = RayInitConfig.from_yaml_path(path) engine_config = config.engine_config_data model = engine_config.model @@ -32,7 +32,7 @@ async def get_result(engine, prompt, sampling_params): request_id = str(uuid.uuid4().hex) results = engine.generate(request_id, prompt, sampling_params) async for result in results: - print("result: ", result) + assert result is not None async def 
asy_for_loop_test(config, prompt, sampling_params): From 2781e97ce295b65e79c756e0e5d7ddd69f8ebaba Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 18 Oct 2023 18:42:32 +0800 Subject: [PATCH 14/15] finish test --- .../dynamic_batching/get_tokenizer.py | 3 ++- .../inference/dynamic_batching/infer_batch.py | 16 ++++++------- .../inference/dynamic_batching/io_struct.py | 24 +++---------------- .../dynamic_batching/ray_dist_init.py | 20 ---------------- .../inference/dynamic_batching/req_queue.py | 4 +++- .../dynamic_batching/sampling_params.py | 17 +++++++------ .../inference/dynamic_batching/stats.py | 2 ++ colossalai/inference/manager.py | 4 +++- .../test_async_engine.py | 2 +- .../test_dynamic_batching_manager.py | 1 + .../test_dynamic_batching/test_ray_dist.py | 8 +++---- 11 files changed, 37 insertions(+), 64 deletions(-) diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py index 0b314f8f886a..94aa3f24393f 100644 --- a/colossalai/inference/dynamic_batching/get_tokenizer.py +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -1,5 +1,6 @@ """ -Motivated by VllM, This module is trying to resolve the tokenizer issue. +Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. + license: MIT, see LICENSE for more details. """ diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index 826272db3e11..112784c15f84 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -1,15 +1,16 @@ +# Adapted from https://github.com/ModelTC/lightllm + import collections from dataclasses import dataclass -from typing import Dict, List , Tuple +from typing import Dict, List, Tuple import numpy as np import torch from colossalai.inference.tensor_parallel import MemoryManager -# make batch infer state an attr of InferBatch - +# make batch infer state an attr of InferBatch class InferSamplingParams: def __init__( self, @@ -65,7 +66,7 @@ def init_batch( cache_manager: MemoryManager, vocab_size: int, max_total_len: int, - ) -> 'InferBatch': + ) -> "InferBatch": input_lengths = [] all_input_ids = [] requests_idx_mapping = {} @@ -76,7 +77,7 @@ def init_batch( nopad_total_token_num = 0 nopad_max_len_in_batch = 0 nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") - # to avoid memory leak , we pre-allocate 12 more space for each batch. + # to avoid memory leak , we pre-allocate 12 more space for each batch. nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") for i, r in enumerate(requests): # request id -> idx in list mapping @@ -142,10 +143,9 @@ def free_self(self) -> None: ) remove_index = torch.cat(remove_index, dim=-1) self.cache_manager.free(remove_index) - @torch.no_grad() - def filter(self, request_ids: List[int]) -> 'InferBatch': + def filter(self, request_ids: List[int]) -> "InferBatch": """ Filter finished batch and return a new InferBatch with left ones. 
""" @@ -226,7 +226,7 @@ def filter(self, request_ids: List[int]) -> 'InferBatch': @classmethod @torch.no_grad() - def merge(cls, batch1, batch2) -> 'InferBatch': + def merge(cls, batch1, batch2) -> "InferBatch": """ Return megerd new InferBatch """ diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 0fdb1141f038..a75eb8007a02 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -1,10 +1,12 @@ +# Adapted from https://github.com/ModelTC/lightllm + from typing import Dict, List, Tuple from .sampling_params import SamplingParams class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -49,26 +51,6 @@ def __repr__(self): return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " -class ReqDetokenizationState: - def __init__( - self, - request_id: str, - prompt_ids: List[int], - max_output_len: int, - ignore_eos: bool, - ) -> None: - self.request_id = request_id - self.prompt_ids = prompt_ids - self.output_ids = [] - self.output_tokens = [] - self.output_str = "" - self.sub_texts = [] - self.current_sub_text = [] - self.max_output_len = max_output_len - self.ignore_eos = ignore_eos - self.gen_metadata = {} - - class Batch: def __init__(self, batch_id, reqs: List[Req]): self.batch_id = batch_id diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index fc150c2d53d4..7639633eaa79 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -134,26 +134,6 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas collective.create_collective_group(self.workers, **_options) _ = ray.get(init_rets) - # set batch wait delay in seconds and maximum number of sequences in a batch - # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - # results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) - # text_res = results[0] # get any one of the copies - # return text_res - - # async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - - # async def async_get(worker, request_id, prompt, sampling_params): - # out = await worker.generate_stream.remote(request_id, prompt, sampling_params) - # print("async_get out: ", out) - # return out - - # all_outputs = [] - # for worker in self.workers: - # all_outputs.append() - # all_outputs = await asyncio.gather(*all_outputs) - # text_res = all_outputs[0] # get any one of the copies - # return text_res - def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py index d9e9b6269cc4..0de43bd1a21f 100644 --- a/colossalai/inference/dynamic_batching/req_queue.py +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import uuid from typing import List @@ -41,7 +43,7 @@ def _can_add_new_req(self, req): need_max_token_num 
= (left_out_len_array * size_array + cum_run_len_array).max() # NOTE: change here < to <= return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size - + def generate_new_batch(self, current_batch: Batch = None): if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: return None diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py index 9a0ace4111dd..2028da907259 100644 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + """Sampling parameters for text generation.""" from typing import List, Optional, Union @@ -5,7 +7,6 @@ class SamplingParams: - def __init__( self, do_sample: bool = False, @@ -13,10 +14,10 @@ def __init__( frequency_penalty: float = 0.0, temperature: float = 1.0, top_p: float = 1.0, - top_k: int = -1, # -1 is for all + top_k: int = -1, # -1 is for all ignore_eos: bool = False, max_new_tokens: int = 16, - stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation + stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation ) -> None: self.do_sample = do_sample self.presence_penalty = presence_penalty @@ -31,11 +32,13 @@ def __init__( self.temperature = 1.0 self.top_p = 1.0 self.top_k = 1 - if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search + if ( + self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS + ): # temperature is too slow, change to greedy search self.temperature = 1.0 self.top_k = 1 return - + def verify(self): if self.presence_penalty < 0.0: raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") @@ -60,13 +63,13 @@ def stop_sentences_to_token_ids(self, tokenizer): new_stop_sequences = [] for stop_str in self.stop_sequences: stop_str_ids = tokenizer.encode(stop_str) - if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id stop_str_ids = stop_str_ids[1:] if len(stop_str_ids) > 0: new_stop_sequences.append(stop_str_ids) self.stop_sequences = new_stop_sequences return - + def to_dict(self): ret = {} ret["do_sample"] = self.do_sample diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py index 6d34183f47c4..524072861a3f 100644 --- a/colossalai/inference/dynamic_batching/stats.py +++ b/colossalai/inference/dynamic_batching/stats.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import time diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 7bced561a91d..42ff8bf1e9ef 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import time from typing import List @@ -51,7 +53,7 @@ def __init__( self.mem_usage_interval = log_stats_interval * 2 self.tokenizer = get_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str = ""): """ Add new request to req queue, during initialization all requests are held in waiting list. 
""" diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index f2f1212b5d07..148d325a1d9a 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -15,7 +15,7 @@ def run_async_engine(path: str): if not os.path.exists(path): - raise FileNotFoundError(f"Invalid yaml file path {path}") + return config = RayInitConfig.from_yaml_path(path) engine_config = config.engine_config_data diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py index 124f1f478b00..588922b5a58f 100644 --- a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -45,6 +45,7 @@ def run(): log_stats=False, log_stats_interval=10, waiting_req_list=waiting_list, + model="llama", ) before_add = len(dynamic_batch_manager.req_queue) diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 3e9573816309..5c84b39d8f8e 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -16,7 +16,7 @@ def run_ray_dist(path: str): print(f"Using yaml file {path}") if not os.path.exists(path): - raise FileNotFoundError(f"Invalid yaml file path {path}") + return config = RayInitConfig.from_yaml_path(path) router_config = config.router_config_data engine_config = config.engine_config_data @@ -37,17 +37,17 @@ async def get_result(request_id, prompt, sampling_params): if test_async: print("test_async: ", test_async) result = asyncio.run(get_result(request_id, prompt, sampling_params)) - assert result is not None + assert result is not None print("result: ", result) else: print("test_async: ", test_async) result = driver.generate(request_id, prompt, sampling_params) assert result is not None print("result: ", result) - + is_running = None is_running = driver.is_running() - assert is_running is not None + assert is_running is not None print("is_running: ", is_running) From 8c53573f2c43349492d61eb0c8628614dc522961 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 19 Oct 2023 10:03:12 +0800 Subject: [PATCH 15/15] add license --- colossalai/inference/async_engine.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index 98af9df6cc8d..a58dde01d250 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -25,20 +25,15 @@ def init_event(self): def add_request(self, request_id: str): """Add a request to be sent to the engine on the next background loop iteration.""" - # if request_id in self._requests: - # raise KeyError(f"Request {request_id} already exists.") - self._requests.put_nowait(request_id) - self.new_requests_event.set() - - def abort_request(self, request_id: str, *, verbose: bool = False) -> None: - """Abort a request during next background loop iteration.""" - if verbose: - logger.info(f"Aborted request {request_id}.") - return + self.new_requests_event.set() # NOTE: we may find a better way to clear this event def add_stop(self): + """ + Add a StopIteration flag to stop async generator. 
+        """
         self._finished_requests.put_nowait(StopIteration)
+        self.new_requests_event.clear()
 
     def process_request_output(self, request_output: RequestOutput) -> None:
         """Process a request output from the engine."""
@@ -61,10 +56,10 @@ async def __anext__(self) -> RequestOutput:
 class Async_Engine:
 
     """
-    loop: start listen
-    add req
-    remove req
-    generate--> return async generator
+    Use an engine to launch the Ray Driver --> Ray Worker --> Async_Manager
+    Background loop: run inference on requests in the waiting list (listen)
+    Request Tracker: manage incoming requests and store finished ones
+    Generate: the exposed method that adds a new request and returns an async generator of outputs
     """
 
     def __init__(
@@ -88,6 +83,9 @@ def _step(self):
             self._request_tracker.process_request_output(request_output)
         self._request_tracker.add_stop()
 
+    def abort(self, request_id: str):
+        self.driver.abort(request_id)
+
     def _has_requests_in_progress(self):
         return self.driver.is_running()
 
@@ -131,5 +129,5 @@ async def generate(self, request_id: str, prompt: str, sampling_params: Sampling
 
         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the request.
-            self._request_tracker.abort_request(request_id)
+            self.abort(request_id)
             raise e
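
For reviewers, a minimal end-to-end sketch of how the async engine added in this series is intended to be driven. It mirrors the usage in tests/test_infer/test_dynamic_batching/test_async_engine.py; the constructor keywords (router_config/engine_config), the config path, and the prompt are illustrative assumptions based on those tests, not a confirmed public API.

# Illustrative sketch only (not part of the patch); constructor keywords and
# config path are assumed from the tests in this series.
import asyncio
import uuid

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def main():
    # Load router/engine settings from the same yaml layout the tests use.
    config = RayInitConfig.from_yaml_path("config.yaml")
    engine = Async_Engine(
        router_config=config.router_config_data,
        engine_config=config.engine_config_data,
    )

    request_id = str(uuid.uuid4().hex)
    sampling_params = SamplingParams()

    # generate() registers the request with the background loop and returns an
    # async generator; iterate it to stream finished RequestOutput objects.
    async for output in engine.generate(request_id, "Hello, my name is", sampling_params):
        print(output)


if __name__ == "__main__":
    asyncio.run(main())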