From ec004fe90cafc89205eee8f849096228c6825c81 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Sat, 14 Oct 2023 12:35:03 +0800 Subject: [PATCH 01/13] Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. --- .../inference/dynamic_batching/io_struct.py | 8 +- colossalai/inference/manager.py | 120 ++++-------------- colossalai/inference/test_async.py | 33 ----- .../test_dynamic_batching/test_forward.py | 10 +- 4 files changed, 32 insertions(+), 139 deletions(-) delete mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 44ad2964a39f..2b2739f0ae90 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -102,21 +102,17 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self): """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. unfinished_req = [] - finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) - else: - finished_req.append(req) + unfinished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} - return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 453570c7ec3e..72f77406789f 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,6 +1,5 @@ import time from typing import List -import asyncio from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -9,8 +8,6 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine -from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -57,20 +54,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, input_ids): - """ - Encode and Add new input to req queue. support one sequence input for now. - """ - prompt_ids = self.tokenizer.encode(input_ids) - prompt_len = len(prompt_ids) - if prompt_len > self.engine.max_input_len: - raise ValueError( - f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" - ) - sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id) - return - def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -83,15 +66,13 @@ def abort(self, request_id): req.aborted = True return - async def loop_for_fwd(self): + def loop_for_fwd(self): """ The main loop for a dynamic batching process. 
""" counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list - while True: - async for item in self._step(): - yield item + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -106,26 +87,6 @@ async def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): - if tokenizer is not None: - self.tokenizer = tokenizer - else: - if "llama" in tokenizer_name.lower() and use_fast == True: - print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: - use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - - def _step(self): """ Logic for handling requests @@ -136,14 +97,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - yield from self._prefill_batch(self.running_batch) + self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -151,18 +112,17 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - yield from self._prefill_batch(new_mini_batch) + self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - else: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -198,8 +158,7 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) - + self._handle_finish_req(batch, has_new_finished_req) # delete finished reqs def _decode_batch(self, batch: Batch): @@ -210,7 +169,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) + self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -242,13 +201,11 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - 
finished_reqs=batch.filter_finished()
+            batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
-
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -261,47 +218,26 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return
 
-    async def _output_process(self, finished_reqs: List[Req]):
-        """
-        Process the output of a batch.
-        """
-        for req in finished_reqs:
-            output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
-
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    async def generate(self,request_id,prompt_id,sampling_params):
-        """
-        Generate the output of a request.
-        """
-        self.add_input(request_id,prompt_id,sampling_params)
-
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    try:
-        batch_manager = DynamicBatchManager(
-            tp_engine=tp_engine,
-            max_total_token_num=args.max_total_token_num,
-            batch_max_tokens=args.batch_max_tokens,
-            eos_id=args.eos_id,
-            log_stats=not args.disable_log_stats,
-            log_stats_interval=args.log_stats_interval,
-            waiting_req_list=waiting_req_list,
-        )
-
-    except Exception:
-        batch_manager.clean_up()
-        raise
-
-    batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
-    prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))
-
-    asyncio.run(prod_task)
-
-    for item in batch_manager.loop_for_fwd():
-        print(item)
-
-    return batch_manager
+    # try:
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+
+    # except Exception:
+    #     batch_manager.clean_up()
+    #     raise
+
+    batch_manager.loop_for_fwd()
+    return
diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py
deleted file mode 100644
index 08720f36da22..000000000000
--- a/colossalai/inference/test_async.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import asyncio
-
-shared_list = []
-
-async def producer():
-    for i in range(5):
-        await asyncio.sleep(1)  # simulate asynchronously fetching data
-        shared_list.append(i)
-        print(f"Produced {i}")
-
-async def consumer():
-    last_index = 0
-    while True:
-        await asyncio.sleep(0.5)  # add a small delay so the loop does not spin too tightly
-        if last_index < len(shared_list):
-            item = shared_list[last_index]
-            print(f"Consumed {item}")
-            yield item
-            last_index += 1
-
-async def main():
-    # create the producer and consumer tasks
-    prod_task = asyncio.create_task(producer())
-
-    # wait for the producer task to finish
-    await prod_task
-
-    async for data in consumer():
-        print(data)
-        # for the purposes of this example, only wait a while, then stop the consumer
-        await asyncio.sleep(5)
-
-asyncio.run(main())
diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py
index ca6401259831..63df491e5b52 100644
--- a/tests/test_infer/test_dynamic_batching/test_forward.py
+++ b/tests/test_infer/test_dynamic_batching/test_forward.py
@@ -42,7 +42,7 @@ def run():
     waiting_list.append(req2)
     waiting_list.append(req3)
     waiting_list.append(req4)
-    
+
     llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
     model = model.half()
 
     shard_config = 
ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) - manager._set_tokenizer(tokenizer_name = model.__class__.__name__) - result_generator = manager.loop_for_fwd() - for result in result_generator: - print(result) - - + start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) def check_dynamic_forward(rank, world_size, port): From d97290af8ade2dcd433cdd8e757fd494fe0edd7b Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Sat, 14 Oct 2023 13:53:51 +0800 Subject: [PATCH 02/13] Add Ray Distributed Environment Init Scripts --- .../inference/dynamic_batching/io_struct.py | 11 +- .../dynamic_batching/ray_dist_init.py | 115 +++++++++++++++++ .../dynamic_batching/ray_init_config.py | 53 ++++++++ colossalai/inference/manager.py | 121 ++++++++++++++---- .../test_dynamic_batching/config.yaml | 15 +++ .../test_dynamic_batching/test_ray_dist.py | 30 +++++ 6 files changed, 314 insertions(+), 31 deletions(-) create mode 100644 colossalai/inference/dynamic_batching/ray_dist_init.py create mode 100644 colossalai/inference/dynamic_batching/ray_init_config.py create mode 100644 tests/test_infer/test_dynamic_batching/config.yaml create mode 100644 tests/test_infer/test_dynamic_batching/test_ray_dist.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 2b2739f0ae90..63165d0a3e5a 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -4,7 +4,7 @@ class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): self.output_metadata_list = [] self.has_generate_finished = False self.aborted = False + self.prompts = prompts def to_rpc_obj(self): return { @@ -102,17 +103,21 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self): + def filter_finished(self)->List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. 
unfinished_req = [] + finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) + else: + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py new file mode 100644 index 000000000000..0359d162f138 --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -0,0 +1,115 @@ +import logging +import os + +import ray +import ray.util.collective as collective +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass + +ray_serve_logger = logging.getLogger("ray.serve") + +def log_cuda_info(scope_name: str): + ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") + ray_serve_logger.info( + f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" + ) + if torch.cuda.is_available(): + ray_serve_logger.info( + f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" + ) + else: + ray_serve_logger.info(f" {scope_name}: cuda is not available!") + +@ray.remote(num_gpus=1) +class Worker: + def __init__(self, model_path: str, tensor_parallel_size: int, max_batch_size: int, max_input_len: int, max_output_len: int, router_config: RooterArgsClass): + log_cuda_info("Worker.init") + self.tensor_parallel_size = tensor_parallel_size + self.model_path = model_path + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.router_config = router_config + + def setup(self, world_size, rank, port): + + # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully + collective.init_collective_group(world_size, rank, "nccl", "default") + # initialize and set distributed environment + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") + log_cuda_info("Worker.setup") + + # Load model + self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 + ) + + shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, []) + + return True + + def generate(self, request_id, prompt, sampling_params) -> str: + + ray_serve_logger.info(f"text: {prompt}") + + results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, 
request_id) + + final_output = None + for request_output in results_generator: + final_output = request_output + + assert final_output is not None + ray_serve_logger.info(f"Generated text: {final_output}") + return final_output + +class Driver: + def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): + log_cuda_info("Driver:init") + model_path = engine_config.model + tensor_parallel_size = engine_config.tensor_parallel_size + + self.num_workers = tensor_parallel_size + self.workers = [] + init_rets = [] + + # Just grab a free port on localhost + # NOTE workers in this communication group listen to the same port + available_port = free_port() + + for i in range(self.num_workers): + worker_name = "worker_idx_{}".format(i) + w = Worker.options(name=worker_name).remote( + model_path, self.num_workers, engine_config.max_batch_size, engine_config.max_input_len, engine_config.max_output_len, router_config + ) + self.workers.append(w) + init_rets.append(w.setup.remote(self.num_workers, i, available_port)) + _options = { + "group_name": "default_driver", + "world_size": self.num_workers, + "ranks": [i for i in range(self.num_workers)], + "backend": "nccl", + } + collective.create_collective_group(self.workers, **_options) + _ = ray.get(init_rets) + + # set batch wait delay in seconds and maximum number of sequences in a batch + def generate(self, request_id, prompt, sampling_params): + results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) + text_res = results[0] # get any one of the copies + return text_res \ No newline at end of file diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py new file mode 100644 index 000000000000..0e89d759e987 --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_init_config.py @@ -0,0 +1,53 @@ +import logging + +import yaml +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +class EngineArgsClass(BaseModel): + """Config for Engine""" + model: str + tensor_parallel_size: int = 2 + max_batch_size: int = 4 + max_input_len: int = 128 + max_output_len: int = 32 + +class RooterArgsClass(BaseModel): + """Config for Rooter""" + max_total_token_num: int = 42 + batch_max_tokens: int = 42 + eos_id: int = 0 + disable_log_stats: bool = False + log_stats_interval: int = 10 + model: str + +class RayInitConfig(BaseModel): + """All-together configs without app router config""" + + engine_config_data: EngineArgsClass + router_config_data: RooterArgsClass + + @classmethod + def from_yaml_path(cls, path: str): + try: + with open(path, "r") as yaml_file: + try: + config = yaml.safe_load(yaml_file) + # serve deployment config + engine_config = config.get("engine_config", {}) + router_config = config.get("router_config", {}) + + return cls( + engine_config_data=engine_config, + router_config_data=router_config, + ) + except yaml.YAMLError as e: + logger.error(f"An Error occurred when parsing yaml: {e}") + raise + except FileNotFoundError: + logger.error(f"The file '{path}' does not exist!") + raise + except OSError as e: + logger.error(f"An Error occurred: {e}") + raise diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 72f77406789f..29af3ae1f934 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -8,6 +8,8 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine +from transformers import AutoTokenizer 
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -16,6 +18,7 @@ def __init__( max_total_token_num, batch_max_tokens, eos_id, + model, log_stats=True, log_stats_interval=10, running_batch: Batch = None, @@ -27,6 +30,7 @@ def __init__( batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir log_stats : whether to log stats log_stats_interval : log stats interval running_batch : running batch @@ -42,18 +46,35 @@ def __init__( self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 + self.model = model self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 + self._set_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): """ Add new request to req queue, during initialization all requests are held in waiting list. """ - req = Req(request_id, prompt_ids, sampling_params) + req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) + print("len(self.req_queue): ", len(self.req_queue)) return + def add_input(self, request_id, sampling_params, prompts): + """ + Encode and Add new input to req queue. support one sequence input for now. + """ + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError( + f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" + ) + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(prompt_ids, sampling_params, request_id, prompts) + return + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -71,8 +92,14 @@ def loop_for_fwd(self): The main loop for a dynamic batching process. """ counter_count = 0 + #self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or self.req_queue.waiting_req_list: - self._step() + if self.running_batch is not None : + print("len(self.running_batch): ", len(self.running_batch)) + else: + print("len(self.running_batch): ", 0) + print("len(self.req_queue.waiting_req_list): ", len(self.req_queue.waiting_req_list)) + yield from self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -87,6 +114,26 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms + def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. 
This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + + def _step(self): """ Logic for handling requests @@ -97,14 +144,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -112,17 +159,18 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -158,7 +206,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -169,7 +218,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -201,11 +250,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs=batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) + def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -218,26 +269,40 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield req.prompts + output + def clean_up(self): # this logic should be implemented in the future. pass + def generate(self,prompts,sampling_params,request_id): + """ + Generate the output of a request. 
+ """ + self.add_input(request_id,sampling_params,prompts) + return self.loop_for_fwd() def start_dynamic_batching(args, tp_engine, waiting_req_list): - # try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - # except Exception: - # batch_manager.clean_up() - # raise - - batch_manager.loop_for_fwd() - return + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + return batch_manager diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml new file mode 100644 index 000000000000..0129f036a00f --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -0,0 +1,15 @@ +engine_config: + model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 + tensor_parallel_size: 2 + max_batch_size: 4 + max_input_len: 128 + max_output_len: 32 +# config for app router deployment +# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig? +router_config: + max_total_token_num: 42 + batch_max_tokens: 42 + eos_id: 0 + disable_log_stats: False + log_stats_interval: 10 + model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py new file mode 100644 index 000000000000..d889db44b277 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -0,0 +1,30 @@ +import os +from typing import Dict +import uuid +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams + +def test_ray_dist(path: str): + print(f"Using yaml file {path}") + if not os.path.exists(path): + raise FileNotFoundError(f"Invalid yaml file path {path}") + config = RayInitConfig.from_yaml_path(path) + router_config = config.router_config_data + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + raise ValueError("Model path not provided or invalid path!") + + driver = Driver(router_config=router_config, engine_config=engine_config) + prompt = 'Introduce some landmarks in Beijing' + + request_id = str(uuid.uuid4().hex) + + sampling_params = SamplingParams() + + print("result: ", prompt + driver.generate(request_id, prompt, sampling_params)) + +if __name__ == "__main__": + path = "config.yaml" + test_ray_dist(path) From f589e97c94c01dd88054a4df007ed19365d9e77c Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Sat, 14 Oct 2023 19:23:39 +0800 Subject: [PATCH 03/13] support DynamicBatchManager base function --- .../dynamic_batching/ray_dist_init.py | 33 +++++++++++++++++-- colossalai/inference/manager.py | 6 ---- 2 files changed, 30 
insertions(+), 9 deletions(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 0359d162f138..9701ca2cde5a 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -13,6 +13,8 @@ from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from typing import List ray_serve_logger = logging.getLogger("ray.serve") @@ -64,7 +66,7 @@ def setup(self, world_size, rank, port): return True - def generate(self, request_id, prompt, sampling_params) -> str: + def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str: ray_serve_logger.info(f"text: {prompt}") @@ -77,6 +79,19 @@ def generate(self, request_id, prompt, sampling_params) -> str: assert final_output is not None ray_serve_logger.info(f"Generated text: {final_output}") return final_output + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) + + def abort(self,request_id: str): + self.start_dynamic_batching.abort(request_id) + + def step(self): + self.start_dynamic_batching._step() + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) + class Driver: def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): @@ -109,7 +124,19 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas _ = ray.get(init_rets) # set batch wait delay in seconds and maximum number of sequences in a batch - def generate(self, request_id, prompt, sampling_params): + def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) text_res = results[0] # get any one of the copies - return text_res \ No newline at end of file + return text_res + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + + def abort(self,request_id: str): + ray.get([w.abort.remote(request_id) for w in self.workers]) + + def step(self): + ray.get([w._step.remote() for w in self.workers]) + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) \ No newline at end of file diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index b5ee1d027e37..6678ecae0816 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -56,7 +56,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques """ req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) - print("len(self.req_queue): ", len(self.req_queue)) return def add_input(self, request_id, sampling_params, prompts): @@ -92,11 +91,6 @@ def loop_for_fwd(self): counter_count = 0 #self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or 
self.req_queue.waiting_req_list: - if self.running_batch is not None : - print("len(self.running_batch): ", len(self.running_batch)) - else: - print("len(self.running_batch): ", 0) - print("len(self.req_queue.waiting_req_list): ", len(self.req_queue.waiting_req_list)) yield from self._step() counter_count += 1 if self.running_batch is not None: From c07005074af453f0c59d34dfdb4306fe7fbd29ff Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 11:55:22 +0800 Subject: [PATCH 04/13] revert _set_tokenizer version --- colossalai/inference/manager.py | 42 ++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 6678ecae0816..06eae3ec0ce3 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -8,6 +8,8 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine +from transformers import AutoTokenizer +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -106,6 +108,26 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms + def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + + def _step(self): """ Logic for handling requests @@ -116,14 +138,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -131,17 +153,18 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -177,7 +200,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = 
batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -188,7 +212,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -220,11 +244,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs=batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) + def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): From 5deb95ced8c76fc1b49ef6b24bd0365d1d59459f Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 14:03:06 +0800 Subject: [PATCH 05/13] add driver async generate --- .../dynamic_batching/ray_dist_init.py | 14 ++++++++-- colossalai/inference/manager.py | 26 ++----------------- .../test_dynamic_batching/test_ray_dist.py | 5 +++- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 9701ca2cde5a..63cf8f33c7a8 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -4,7 +4,7 @@ import ray import ray.util.collective as collective import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine @@ -14,7 +14,9 @@ from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer from typing import List +import asyncio ray_serve_logger = logging.getLogger("ray.serve") @@ -51,7 +53,7 @@ def setup(self, world_size, rank, port): log_cuda_info("Worker.setup") # Load model - self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + self.tokenizer = get_tokenizer(tokenizer_name = self.model_path) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = AutoModelForCausalLM.from_pretrained( @@ -129,6 +131,14 @@ def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams text_res = results[0] # get any one of the copies return text_res + async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + all_outputs = [] + for worker in self.workers: + all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) + all_outputs = await asyncio.gather(*all_outputs) + text_res = all_outputs[0]# get any one of the copies + return text_res + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) 
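# Illustrative sketch (not part of the patch): the fan-out pattern used by
# Driver.async_generate above. Each Worker holds one tensor-parallel shard, so the
# same request is dispatched to every worker; the results are identical copies and
# any one of them can be returned. Ray ObjectRefs are awaitable, so asyncio.gather
# waits on all shards concurrently without blocking the event loop the way the
# blocking ray.get(...) in Driver.generate does. EchoWorker and demo() below are
# hypothetical stand-ins for the Worker actor and Driver defined in this file.
import asyncio

import ray


@ray.remote
class EchoWorker:
    # Trivial stand-in for Worker.generate: deterministic output, so all
    # "shards" return identical copies, as the real tensor-parallel workers do.
    def generate(self, request_id: str, prompt: str) -> str:
        return f"[{request_id}] {prompt} -> generated"


async def demo() -> str:
    workers = [EchoWorker.remote() for _ in range(2)]
    # One remote call per shard; the returned ObjectRefs double as awaitables.
    refs = [w.generate.remote("req-0", "hello") for w in workers]
    outputs = await asyncio.gather(*refs)
    return outputs[0]  # all copies are identical, so take any one


if __name__ == "__main__":
    ray.init(num_cpus=2)
    print(asyncio.run(demo()))
    ray.shutdown()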
diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 06eae3ec0ce3..26d93eb1f14a 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -7,9 +7,7 @@ from .dynamic_batching.sampling_params import SamplingParams from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine - -from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" +from .dynamic_batching.get_tokenizer import get_tokenizer class DynamicBatchManager: def __init__( @@ -50,7 +48,7 @@ def __init__( self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 - self._set_tokenizer(tokenizer_name=self.model) + self.tokenizer = get_tokenizer(tokenizer_name=self.model) def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): """ @@ -108,26 +106,6 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): - if tokenizer is not None: - self.tokenizer = tokenizer - else: - if "llama" in tokenizer_name.lower() and use_fast == True: - print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: - use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - - def _step(self): """ Logic for handling requests diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index d889db44b277..c943c74eb456 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -23,7 +23,10 @@ def test_ray_dist(path: str): sampling_params = SamplingParams() - print("result: ", prompt + driver.generate(request_id, prompt, sampling_params)) + result_generator = driver.generate(request_id, prompt, sampling_params) + + for result in result_generator: + print("result: ", result) if __name__ == "__main__": path = "config.yaml" From 306ef77a0c09fd3224ab1dc400ca313ef3db49ef Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 14:38:31 +0800 Subject: [PATCH 06/13] add async test --- .../inference/dynamic_batching/io_struct.py | 8 +++++-- .../test_dynamic_batching/test_ray_dist.py | 21 +++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index fe5f25e2ea11..63165d0a3e5a 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -103,17 +103,21 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self): + def filter_finished(self)->List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. 
unfinished_req = [] + finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) + else: + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index c943c74eb456..a7bc7b2df246 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -1,9 +1,9 @@ import os -from typing import Dict import uuid from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from colossalai.inference.dynamic_batching.ray_dist_init import Driver from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +import asyncio def test_ray_dist(path: str): print(f"Using yaml file {path}") @@ -23,11 +23,20 @@ def test_ray_dist(path: str): sampling_params = SamplingParams() - result_generator = driver.generate(request_id, prompt, sampling_params) - - for result in result_generator: - print("result: ", result) + async def get_result(request_id, prompt, sampling_params): + return await driver.generate(request_id, prompt, sampling_params) + + for test_async in [True, False]: + if test_async: + print("test_async: ", test_async) + result = get_result(request_id, prompt, sampling_params) + print("result: ", result) + else: + print("test_async: ", test_async) + result = driver.generate(request_id, prompt, sampling_params) + print("result: ", result) + if __name__ == "__main__": path = "config.yaml" - test_ray_dist(path) + test_ray_dist(path) \ No newline at end of file From 632f0e1107f4d454e5e7f28be72b14b455acd562 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 14:52:56 +0800 Subject: [PATCH 07/13] fix bugs in test_ray_dist.py --- tests/test_infer/test_dynamic_batching/test_ray_dist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index a7bc7b2df246..9bf5ff68b6ae 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -24,12 +24,12 @@ def test_ray_dist(path: str): sampling_params = SamplingParams() async def get_result(request_id, prompt, sampling_params): - return await driver.generate(request_id, prompt, sampling_params) + return await driver.async_generate(request_id, prompt, sampling_params) for test_async in [True, False]: if test_async: print("test_async: ", test_async) - result = get_result(request_id, prompt, sampling_params) + result = asyncio.run(get_result(request_id, prompt, sampling_params)) print("result: ", result) else: print("test_async: ", test_async) From 0b2fe513f3aad5f5b8092ed3989af994caab774b Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 15:10:36 +0800 Subject: [PATCH 08/13] add get_tokenizer.py --- .../dynamic_batching/get_tokenizer.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 colossalai/inference/dynamic_batching/get_tokenizer.py diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py new file mode 100644 index 000000000000..ea8116ce66f5 --- /dev/null +++ 
b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -0,0 +1,22 @@ +from transformers import AutoTokenizer +_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" + +def get_tokenizer(tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + return tokenizer \ No newline at end of file From cd843ac8f2fc7e9cedee641a4f110b29f03e9d84 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 15:35:38 +0800 Subject: [PATCH 09/13] fix code style --- .../dynamic_batching/get_tokenizer.py | 42 +++++++----- .../inference/dynamic_batching/io_struct.py | 12 ++-- .../dynamic_batching/ray_dist_init.py | 64 +++++++++++-------- .../dynamic_batching/ray_init_config.py | 5 ++ colossalai/inference/manager.py | 27 ++++---- .../test_dynamic_batching/config.yaml | 4 +- .../test_dynamic_batching/test_ray_dist.py | 30 +++++---- 7 files changed, 109 insertions(+), 75 deletions(-) diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py index ea8116ce66f5..af1f26848b3a 100644 --- a/colossalai/inference/dynamic_batching/get_tokenizer.py +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -1,22 +1,34 @@ from transformers import AutoTokenizer + _FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" - -def get_tokenizer(tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + + +def get_tokenizer( + tokenizer=None, + tokenizer_name: str = "", + trust_remote_code: bool = False, + use_fast: bool = True, +): if tokenizer is not None: - tokenizer = tokenizer + tokenizer = tokenizer else: if "llama" in tokenizer_name.lower() and use_fast == True: print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." 
+ ) + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + except TypeError: use_fast = False - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - return tokenizer \ No newline at end of file + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + return tokenizer diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 63165d0a3e5a..9faaad6f111e 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -37,7 +37,11 @@ def stop_sequences_matched(self): if self.sample_params.stop_sequences is not None: for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) - if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): return True return False @@ -103,7 +107,7 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self) -> List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ @@ -112,9 +116,9 @@ def filter_finished(self)->List[Req]: finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) else: - finished_req.append(req) + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} return finished_req diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 63cf8f33c7a8..a40a00e2666c 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -1,5 +1,7 @@ +import asyncio import logging import os +from typing import List import ray import ray.util.collective as collective @@ -7,19 +9,17 @@ from transformers import AutoModelForCausalLM import colossalai +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import free_port -from colossalai.inference.manager import start_dynamic_batching -from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass -from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer -from typing import List -import asyncio - ray_serve_logger = logging.getLogger("ray.serve") + def log_cuda_info(scope_name: str): ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") ray_serve_logger.info( @@ -32,9 +32,18 @@ def log_cuda_info(scope_name: str): else: 
ray_serve_logger.info(f" {scope_name}: cuda is not available!") + @ray.remote(num_gpus=1) class Worker: - def __init__(self, model_path: str, tensor_parallel_size: int, max_batch_size: int, max_input_len: int, max_output_len: int, router_config: RooterArgsClass): + def __init__( + self, + model_path: str, + tensor_parallel_size: int, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + router_config: RooterArgsClass, + ): log_cuda_info("Worker.init") self.tensor_parallel_size = tensor_parallel_size self.model_path = model_path @@ -44,7 +53,6 @@ def __init__(self, model_path: str, tensor_parallel_size: int, max_batch_size: i self.router_config = router_config def setup(self, world_size, rank, port): - # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully collective.init_collective_group(world_size, rank, "nccl", "default") # initialize and set distributed environment @@ -53,7 +61,7 @@ def setup(self, world_size, rank, port): log_cuda_info("Worker.setup") # Load model - self.tokenizer = get_tokenizer(tokenizer_name = self.model_path) + self.tokenizer = get_tokenizer(tokenizer_name=self.model_path) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = AutoModelForCausalLM.from_pretrained( @@ -69,7 +77,6 @@ def setup(self, world_size, rank, port): return True def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str: - ray_serve_logger.info(f"text: {prompt}") results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) @@ -81,19 +88,19 @@ def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams assert final_output is not None ray_serve_logger.info(f"Generated text: {final_output}") return final_output - + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) - - def abort(self,request_id: str): + + def abort(self, request_id: str): self.start_dynamic_batching.abort(request_id) - + def step(self): self.start_dynamic_batching._step() - + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - + class Driver: def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): @@ -112,7 +119,12 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas for i in range(self.num_workers): worker_name = "worker_idx_{}".format(i) w = Worker.options(name=worker_name).remote( - model_path, self.num_workers, engine_config.max_batch_size, engine_config.max_input_len, engine_config.max_output_len, router_config + model_path, + self.num_workers, + engine_config.max_batch_size, + engine_config.max_input_len, + engine_config.max_output_len, + router_config, ) self.workers.append(w) init_rets.append(w.setup.remote(self.num_workers, i, available_port)) @@ -130,23 +142,23 @@ def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) text_res = results[0] # get any one of the copies return text_res - + async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): all_outputs = [] for worker in self.workers: all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) 
all_outputs = await asyncio.gather(*all_outputs)
-        text_res = all_outputs[0]# get any one of the copies
+        text_res = all_outputs[0]  # get any one of the copies
         return text_res
-    
+
     def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
         ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers])
-    
-    def abort(self,request_id: str):
+
+    def abort(self, request_id: str):
         ray.get([w.abort.remote(request_id) for w in self.workers])
-    
+
     def step(self):
         ray.get([w.step.remote() for w in self.workers])
-    
+
     def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
-        ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
\ No newline at end of file
+        ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py
index 0e89d759e987..471f07330aec 100644
--- a/colossalai/inference/dynamic_batching/ray_init_config.py
+++ b/colossalai/inference/dynamic_batching/ray_init_config.py
@@ -5,16 +5,20 @@
 
 logger = logging.getLogger(__name__)
 
+
 class EngineArgsClass(BaseModel):
     """Config for Engine"""
+
     model: str
     tensor_parallel_size: int = 2
     max_batch_size: int = 4
     max_input_len: int = 128
     max_output_len: int = 32
 
+
 class RooterArgsClass(BaseModel):
     """Config for Rooter"""
+
     max_total_token_num: int = 42
     batch_max_tokens: int = 42
     eos_id: int = 0
@@ -22,6 +26,7 @@ class RooterArgsClass(BaseModel):
     log_stats_interval: int = 10
     model: str
 
+
 class RayInitConfig(BaseModel):
     """All-together configs without app router config"""
 
diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py
index 26d93eb1f14a..30717a915e3b 100644
--- a/colossalai/inference/manager.py
+++ b/colossalai/inference/manager.py
@@ -1,13 +1,14 @@
 import time
 from typing import List
 
+from .dynamic_batching.get_tokenizer import get_tokenizer
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
 from .dynamic_batching.req_queue import ReqQueue
 from .dynamic_batching.sampling_params import SamplingParams
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
-from .dynamic_batching.get_tokenizer import get_tokenizer
+
 
 class DynamicBatchManager:
     def __init__(
@@ -45,7 +46,7 @@ def __init__(
         self.has_wait_tokens = 0
         self.max_wait_tokens = 10
         self.model = model
-        
+
         self.stats_tool = Stats(log_stats, log_stats_interval)
         self.mem_usage_interval = log_stats_interval * 2
         self.tokenizer = get_tokenizer(tokenizer_name=self.model)
@@ -65,13 +66,11 @@ def add_input(self, request_id, sampling_params, prompts):
         prompt_ids = self.tokenizer.encode(prompts)
         prompt_len = len(prompt_ids)
         if prompt_len > self.engine.max_input_len:
-            raise ValueError(
-                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
-            )
+            raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
         sampling_params.stop_sentences_to_token_ids(self.tokenizer)
         self.add_req(prompt_ids, sampling_params, request_id, prompts)
         return
-    
+
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -89,7 +88,7 @@ def loop_for_fwd(self):
         The main loop for a dynamic batching process.
""" counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list + # self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or self.req_queue.waiting_req_list: yield from self._step() counter_count += 1 @@ -136,13 +135,13 @@ def _step(self): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - + else: self.stats_tool.count_output_tokens(self.running_batch) yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -179,7 +178,7 @@ def _prefill_batch(self, batch): self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) yield from self._handle_finish_req(batch, has_new_finished_req) - + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -222,14 +221,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs=batch.filter_finished() + finished_reqs = batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) yield from self._output_process(finished_reqs) - def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): self.running_batch = None @@ -253,13 +251,14 @@ def clean_up(self): # this logic should be implemented in the future. pass - def generate(self,prompts,sampling_params,request_id): + def generate(self, prompts, sampling_params, request_id): """ Generate the output of a request. """ - self.add_input(request_id,sampling_params,prompts) + self.add_input(request_id, sampling_params, prompts) return self.loop_for_fwd() + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: batch_manager = DynamicBatchManager( diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml index 0129f036a00f..c31ae8c5fadb 100644 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -1,5 +1,5 @@ engine_config: - model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 + model: MODEL_PATH tensor_parallel_size: 2 max_batch_size: 4 max_input_len: 128 @@ -12,4 +12,4 @@ router_config: eos_id: 0 disable_log_stats: False log_stats_interval: 10 - model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 + model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 9bf5ff68b6ae..09f41ba137de 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -1,9 +1,11 @@ +import asyncio import os import uuid -from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig + from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -import asyncio + def test_ray_dist(path: str): print(f"Using yaml file {path}") @@ -15,28 +17,28 @@ def test_ray_dist(path: str): model = engine_config.model if model is None or not 
os.path.exists(model):
         raise ValueError("Model path not provided or invalid path!")
-    
+
     driver = Driver(router_config=router_config, engine_config=engine_config)
-    prompt = 'Introduce some landmarks in Beijing'
-    
+    prompt = "Introduce some landmarks in Beijing"
+
     request_id = str(uuid.uuid4().hex)
-    
+
     sampling_params = SamplingParams()
-    
+
     async def get_result(request_id, prompt, sampling_params):
         return await driver.async_generate(request_id, prompt, sampling_params)
-    
+
     for test_async in [True, False]:
-        if test_async: 
+        if test_async:
             print("test_async: ", test_async)
-            result = asyncio.run(get_result(request_id, prompt, sampling_params)) 
+            result = asyncio.run(get_result(request_id, prompt, sampling_params))
             print("result: ", result)
         else:
             print("test_async: ", test_async)
-            result = driver.generate(request_id, prompt, sampling_params) 
+            result = driver.generate(request_id, prompt, sampling_params)
             print("result: ", result)
-    
-    
+
+
 if __name__ == "__main__":
     path = "config.yaml"
-    test_ray_dist(path)
\ No newline at end of file
+    test_ray_dist(path)

From 8c9ad51484064055c7d8262d5455f641860cbd72 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Mon, 16 Oct 2023 16:10:42 +0800
Subject: [PATCH 10/13] fix missing pydantic dependency in CI tests

---
 requirements/requirements-test.txt | 1 +
 requirements/requirements.txt | 1 +
 2 files changed, 2 insertions(+)

diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 467f83610eb0..e22c1d1a5127 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -18,4 +18,5 @@ SentencePiece
 ninja
 flash_attn==2.0.5
 datasets
+pydantic
 #auto-gptq now not support torch1.12
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 9aa5f2822e40..421784f3de87 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -11,3 +11,4 @@ ninja
 torch>=1.12
 safetensors
 einops
+pydantic

From 8d0cc6b51a8690c1fca0f42b94a5c5dcf46ba70b Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Mon, 16 Oct 2023 16:23:03 +0800
Subject: [PATCH 11/13] fix missing ray dependency in CI tests

---
 requirements/requirements-test.txt | 1 +
 requirements/requirements.txt | 1 +
 2 files changed, 2 insertions(+)

diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index e22c1d1a5127..f54b13c7e43c 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -19,4 +19,5 @@ ninja
 flash_attn==2.0.5
 datasets
 pydantic
+ray
 #auto-gptq now not support torch1.12
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 421784f3de87..8a4b0f1a0ffd 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -12,3 +12,4 @@ torch>=1.12
 safetensors
 einops
 pydantic
+ray

From acdd751a2fd080b6c0d61f538a686513fe7cb818 Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Mon, 16 Oct 2023 16:50:49 +0800
Subject: [PATCH 12/13] wrap ray dist test in colossalai spawn for CI

---
 .../test_dynamic_batching/test_ray_dist.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
index 09f41ba137de..76e47c7eabd3 100644
--- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py
+++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
@@ -5,6 +5,9 @@
 from colossalai.inference.dynamic_batching.ray_dist_init import Driver
 from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
 from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
+import colossalai
+import pytest
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
 
 
 def test_ray_dist(path: str):
@@ -16,8 +19,7 @@ def test_ray_dist(path: str):
     engine_config = config.engine_config_data
     model = engine_config.model
     if model is None or not os.path.exists(model):
-        raise ValueError("Model path not provided or invalid path!")
-
+        return
     driver = Driver(router_config=router_config, engine_config=engine_config)
 
     prompt = "Introduce some landmarks in Beijing"
@@ -38,6 +40,16 @@ async def get_result(request_id, prompt, sampling_params):
             result = driver.generate(request_id, prompt, sampling_params)
             print("result: ", result)
 
+def check_dynamic_batching_manager(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    test_ray_dist()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_dynamic_batching_manager():
+    spawn(check_dynamic_batching_manager, 1)
 
 if __name__ == "__main__":
     path = "config.yaml"

From 8a761bdad4f3907f3fe42c680c1edb15c8ad342a Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Mon, 16 Oct 2023 17:38:08 +0800
Subject: [PATCH 13/13] refactor ray dist test entry points for CI

---
 .../test_dynamic_batching/test_ray_dist.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
index 76e47c7eabd3..4cf9881f41dc 100644
--- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py
+++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
@@ -9,8 +9,9 @@
 import pytest
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
 
+PATH = "config.yaml"
 
-def test_ray_dist(path: str):
+def run_ray_dist(path: str):
     print(f"Using yaml file {path}")
     if not os.path.exists(path):
         raise FileNotFoundError(f"Invalid yaml file path {path}")
@@ -40,17 +41,16 @@ async def get_result(request_id, prompt, sampling_params):
             result = driver.generate(request_id, prompt, sampling_params)
             print("result: ", result)
 
-def check_dynamic_batching_manager(rank, world_size, port):
+def check_ray_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    test_ray_dist()
+    run_ray_dist(PATH)
 
 
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
-def test_dynamic_batching_manager():
-    spawn(check_dynamic_batching_manager, 1)
+def test_ray_dist():
+    spawn(check_ray_dist, 1)
 
 if __name__ == "__main__":
-    path = "config.yaml"
-    test_ray_dist(path)
+    test_ray_dist()
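
For orientation, a minimal sketch of how the API that patches 10-13 settle on can be driven outside pytest. This is a hedged sketch, not part of the patch series: it assumes a reachable Ray cluster with GPUs, that MODEL_PATH points at a local checkpoint (a placeholder, like the MODEL_PATH stub in config.yaml), and that the pydantic defaults in ray_init_config.py are usable as-is; the tests instead load both configs from config.yaml.

import asyncio
import uuid

from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

MODEL_PATH = "/path/to/llama-7b-hf"  # placeholder, substitute a real checkpoint

# Among the fields shown in ray_init_config.py, only `model` has no default;
# tensor_parallel_size, max_batch_size, and the token budgets fall back to
# the defaults declared on the pydantic classes.
engine_config = EngineArgsClass(model=MODEL_PATH)
router_config = RooterArgsClass(model=MODEL_PATH)


async def main() -> None:
    # Driver spawns one Ray worker per rank and runs setup() on each:
    # collective group, model load, TPInferEngine, dynamic batching manager.
    driver = Driver(router_config=router_config, engine_config=engine_config)
    request_id = str(uuid.uuid4().hex)
    # async_generate awaits every worker's copy via asyncio.gather and
    # returns one of them.
    output = await driver.async_generate(request_id, "Introduce some landmarks in Beijing", SamplingParams())
    print(output)


if __name__ == "__main__":
    asyncio.run(main())

The synchronous equivalent is driver.generate(request_id, prompt, sampling_params), which ray.get()s all worker copies and returns the first, exactly as the non-async branch of the test exercises.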