diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py
new file mode 100644
index 000000000000..af1f26848b3a
--- /dev/null
+++ b/colossalai/inference/dynamic_batching/get_tokenizer.py
@@ -0,0 +1,33 @@
+from transformers import AutoTokenizer
+
+_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer"
+
+
+def get_tokenizer(
+    tokenizer=None,
+    tokenizer_name: str = "",
+    trust_remote_code: bool = False,
+    use_fast: bool = True,
+):
+    if tokenizer is None:
+        if "llama" in tokenizer_name.lower() and use_fast:
+            print(
+                "For some LLaMA-based models, initializing the fast tokenizer may "
+                "take a long time. To eliminate the initialization time, consider "
+                f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                "tokenizer. This is done automatically in Colossal-AI."
+            )
+
+            tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+            )
+        except TypeError:
+            # some tokenizers do not accept the use_fast kwarg; fall back to the slow one
+            use_fast = False
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+            )
+    return tokenizer
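A minimal usage sketch for the helper above ("gpt2" is an illustrative hub id, not part of this patch; note that any name containing "llama" with `use_fast=True` is redirected to the hardcoded `_FAST_LLAMA_TOKENIZER` path):

```python
# Sketch: resolving a tokenizer through the new helper. "gpt2" is a
# hypothetical example id; any HF model directory or hub id works.
from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name="gpt2", use_fast=True)
ids = tokenizer.encode("Introduce some landmarks in Beijing")
print(len(ids), tokenizer.decode(ids))
```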
diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py
index 2b2739f0ae90..9faaad6f111e 100644
--- a/colossalai/inference/dynamic_batching/io_struct.py
+++ b/colossalai/inference/dynamic_batching/io_struct.py
@@ -4,7 +4,7 @@
 
 
 class Req:
-    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
+    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str):
         self.request_id = request_id
         self.prompt_ids = prompt_ids
         self.input_len = len(prompt_ids)
@@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
         self.output_metadata_list = []
         self.has_generate_finished = False
         self.aborted = False
+        self.prompts = prompts
 
     def to_rpc_obj(self):
         return {
@@ -36,7 +37,11 @@ def stop_sequences_matched(self):
         if self.sample_params.stop_sequences is not None:
             for stop_token_ids in self.sample_params.stop_sequences:
                 stop_len = len(stop_token_ids)
-                if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
+                if (
+                    stop_len > 0
+                    and len(self.output_ids) >= stop_len
+                    and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
+                ):
                     return True
         return False
@@ -102,17 +107,21 @@ def mark_finished_req(self, eos_id):
                 has_new_finish = True
         return has_new_finish
 
-    def filter_finished(self):
+    def filter_finished(self) -> List[Req]:
         """
         Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
         """
         # TODO: the logic of return should be defined here.
         unfinished_req = []
+        finished_req = []
         for req in self.reqs:
             if not req.has_generate_finished:
                 unfinished_req.append(req)
+            else:
+                finished_req.append(req)
         self.reqs = unfinished_req
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
+        return finished_req
 
     def is_clear(self):
         return len(self.reqs) == 0
diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py
new file mode 100644
index 000000000000..a40a00e2666c
--- /dev/null
+++ b/colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -0,0 +1,163 @@
+import asyncio
+import logging
+import os
+from typing import List
+
+import ray
+import ray.util.collective as collective
+import torch
+from transformers import AutoModelForCausalLM
+
+import colossalai
+from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer
+from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RouterArgsClass
+from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
+from colossalai.inference.manager import start_dynamic_batching
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import free_port
+
+ray_serve_logger = logging.getLogger("ray.serve")
+
+
+def log_cuda_info(scope_name: str):
+    ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
+    ray_serve_logger.info(
+        f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
+    )
+    if torch.cuda.is_available():
+        ray_serve_logger.info(
+            f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
+        )
+    else:
+        ray_serve_logger.info(f" {scope_name}: cuda is not available!")
+
+
+@ray.remote(num_gpus=1)
+class Worker:
+    def __init__(
+        self,
+        model_path: str,
+        tensor_parallel_size: int,
+        max_batch_size: int,
+        max_input_len: int,
+        max_output_len: int,
+        router_config: RouterArgsClass,
+    ):
+        log_cuda_info("Worker.init")
+        self.tensor_parallel_size = tensor_parallel_size
+        self.model_path = model_path
+        self.max_batch_size = max_batch_size
+        self.max_input_len = max_input_len
+        self.max_output_len = max_output_len
+        self.router_config = router_config
+
+    def setup(self, world_size, rank, port):
+        # initialize a ray collective group, otherwise the colossalai distributed env won't be built successfully
+        collective.init_collective_group(world_size, rank, "nccl", "default")
+        # initialize and set the distributed environment
+        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+        ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up...")
+        log_cuda_info("Worker.setup")
+
+        # load the tokenizer and model
+        self.tokenizer = get_tokenizer(tokenizer_name=self.model_path)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
+        )
+
+        shard_config = ShardConfig(enable_tensor_parallelism=world_size > 1, inference_only=True)
+        self.infer_engine = TPInferEngine(
+            self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
+        )
+        self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, [])
+
+        return True
+
+    def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str:
+        ray_serve_logger.info(f"text: {prompt}")
+
+        results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id)
+
+        final_output = None
+        for request_output in results_generator:
+            final_output = request_output
+
+        assert final_output is not None
+        ray_serve_logger.info(f"Generated text: {final_output}")
+        return final_output
+
+    def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+        self.start_dynamic_batching.add_input(request_id, sampling_params, prompt)
+
+    def abort(self, request_id: str):
+        self.start_dynamic_batching.abort(request_id)
+
+    def step(self):
+        self.start_dynamic_batching._step()
+
+    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
+        self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt)
+
+
+class Driver:
+    def __init__(self, router_config: RouterArgsClass, engine_config: EngineArgsClass):
+        log_cuda_info("Driver:init")
+        model_path = engine_config.model
+        tensor_parallel_size = engine_config.tensor_parallel_size
+
+        self.num_workers = tensor_parallel_size
+        self.workers = []
+        init_rets = []
+
+        # just grab a free port on localhost
+        # NOTE workers in this communication group listen to the same port
+        available_port = free_port()
+
+        for i in range(self.num_workers):
+            worker_name = "worker_idx_{}".format(i)
+            w = Worker.options(name=worker_name).remote(
+                model_path,
+                self.num_workers,
+                engine_config.max_batch_size,
+                engine_config.max_input_len,
+                engine_config.max_output_len,
+                router_config,
+            )
+            self.workers.append(w)
+            init_rets.append(w.setup.remote(self.num_workers, i, available_port))
+        _options = {
+            "group_name": "default_driver",
+            "world_size": self.num_workers,
+            "ranks": [i for i in range(self.num_workers)],
+            "backend": "nccl",
+        }
+        collective.create_collective_group(self.workers, **_options)
+        _ = ray.get(init_rets)
+
+    def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+        results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers])
+        text_res = results[0]  # every worker returns a copy of the result; take any one
+        return text_res
+
+    async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+        all_outputs = []
+        for worker in self.workers:
+            all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params))
+        all_outputs = await asyncio.gather(*all_outputs)
+        text_res = all_outputs[0]  # every worker returns a copy of the result; take any one
+        return text_res
+
+    def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+        ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers])
+
+    def abort(self, request_id: str):
+        ray.get([w.abort.remote(request_id) for w in self.workers])
+
+    def step(self):
+        ray.get([w.step.remote() for w in self.workers])
+
+    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
+        ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
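For reference, a hedged end-to-end sketch of how `Driver` is meant to be used. It assumes a reachable Ray cluster, at least `tensor_parallel_size` visible GPUs, and a `config.yaml` like the one added under `tests/test_infer/test_dynamic_batching/`:

```python
# Sketch: driving tensor-parallel workers through the new Driver class.
# Assumes ray is installed, GPUs are available, and config.yaml points at
# a real model directory.
import asyncio
import uuid

import ray

from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

ray.init()  # connect to (or start) a local Ray cluster

config = RayInitConfig.from_yaml_path("config.yaml")
driver = Driver(router_config=config.router_config_data, engine_config=config.engine_config_data)

request_id = uuid.uuid4().hex
params = SamplingParams()

# synchronous path: blocks until every worker returns its copy of the result
text = driver.generate(request_id, "Introduce some landmarks in Beijing", params)

# asynchronous path: awaits the same fan-out without blocking the event loop
text = asyncio.run(driver.async_generate(request_id, "Introduce some landmarks in Beijing", params))
print(text)
```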
diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py
new file mode 100644
index 000000000000..471f07330aec
--- /dev/null
+++ b/colossalai/inference/dynamic_batching/ray_init_config.py
@@ -0,0 +1,58 @@
+import logging
+
+import yaml
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class EngineArgsClass(BaseModel):
+    """Config for the inference engine."""
+
+    model: str
+    tensor_parallel_size: int = 2
+    max_batch_size: int = 4
+    max_input_len: int = 128
+    max_output_len: int = 32
+
+
+class RouterArgsClass(BaseModel):
+    """Config for the request router."""
+
+    max_total_token_num: int = 42
+    batch_max_tokens: int = 42
+    eos_id: int = 0
+    disable_log_stats: bool = False
+    log_stats_interval: int = 10
+    model: str
+
+
+class RayInitConfig(BaseModel):
+    """Aggregated engine and router configs, loaded from one yaml file."""
+
+    engine_config_data: EngineArgsClass
+    router_config_data: RouterArgsClass
+
+    @classmethod
+    def from_yaml_path(cls, path: str):
+        try:
+            with open(path, "r") as yaml_file:
+                try:
+                    config = yaml.safe_load(yaml_file)
+                    # serve deployment config
+                    engine_config = config.get("engine_config", {})
+                    router_config = config.get("router_config", {})
+
+                    return cls(
+                        engine_config_data=engine_config,
+                        router_config_data=router_config,
+                    )
+                except yaml.YAMLError as e:
+                    logger.error(f"An error occurred while parsing the yaml file: {e}")
+                    raise
+        except FileNotFoundError:
+            logger.error(f"The file '{path}' does not exist!")
+            raise
+        except OSError as e:
+            logger.error(f"An error occurred: {e}")
+            raise
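Since these are pydantic models, missing or mistyped YAML fields fail loudly at validation time. A small sketch of building the same config programmatically (the field values shown are the class defaults; the model path is a placeholder):

```python
# Sketch: constructing the configs directly, mirroring what from_yaml_path()
# builds from the "engine_config" and "router_config" YAML sections.
from colossalai.inference.dynamic_batching.ray_init_config import (
    EngineArgsClass,
    RayInitConfig,
    RouterArgsClass,
)

engine = EngineArgsClass(model="/path/to/model", tensor_parallel_size=2)
router = RouterArgsClass(model="/path/to/model", max_total_token_num=42)
config = RayInitConfig(engine_config_data=engine, router_config_data=router)

print(config.engine_config_data.max_batch_size)  # 4 (default)
```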
+ """ + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(prompt_ids, sampling_params, request_id, prompts) + return + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -71,8 +88,9 @@ def loop_for_fwd(self): The main loop for a dynamic batching process. """ counter_count = 0 + # self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or self.req_queue.waiting_req_list: - self._step() + yield from self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -97,14 +115,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -112,14 +130,15 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 @@ -158,7 +177,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -169,7 +189,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -201,11 +221,12 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs = batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -218,26 +239,41 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. 
+ """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield req.prompts + output + def clean_up(self): # this logic should be implemented in the future. pass + def generate(self, prompts, sampling_params, request_id): + """ + Generate the output of a request. + """ + self.add_input(request_id, sampling_params, prompts) + return self.loop_for_fwd() + def start_dynamic_batching(args, tp_engine, waiting_req_list): - # try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - # except Exception: - # batch_manager.clean_up() - # raise - - batch_manager.loop_for_fwd() - return + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + return batch_manager diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 467f83610eb0..f54b13c7e43c 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,4 +18,6 @@ SentencePiece ninja flash_attn==2.0.5 datasets +pydantic +ray #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9aa5f2822e40..8a4b0f1a0ffd 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,3 +11,5 @@ ninja torch>=1.12 safetensors einops +pydantic +ray diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml new file mode 100644 index 000000000000..c31ae8c5fadb --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -0,0 +1,15 @@ +engine_config: + model: MODEL_PATH + tensor_parallel_size: 2 + max_batch_size: 4 + max_input_len: 128 + max_output_len: 32 +# config for app router deployment +# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig? 
diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml
new file mode 100644
index 000000000000..c31ae8c5fadb
--- /dev/null
+++ b/tests/test_infer/test_dynamic_batching/config.yaml
@@ -0,0 +1,15 @@
+engine_config:
+  model: MODEL_PATH
+  tensor_parallel_size: 2
+  max_batch_size: 4
+  max_input_len: 128
+  max_output_len: 32
+# config for the app router deployment
+# resources assigned to each model replica; this should correspond to the Ray AIR ScalingConfig
+router_config:
+  max_total_token_num: 42
+  batch_max_tokens: 42
+  eos_id: 0
+  disable_log_stats: False
+  log_stats_interval: 10
+  model: MODEL_PATH
diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
new file mode 100644
index 000000000000..4cf9881f41dc
--- /dev/null
+++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py
@@ -0,0 +1,59 @@
+import asyncio
+import os
+import uuid
+
+import pytest
+
+import colossalai
+from colossalai.inference.dynamic_batching.ray_dist_init import Driver
+from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
+from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+
+PATH = "config.yaml"
+
+
+def run_ray_dist(path: str):
+    print(f"Using yaml file {path}")
+    if not os.path.exists(path):
+        raise FileNotFoundError(f"Invalid yaml file path {path}")
+    config = RayInitConfig.from_yaml_path(path)
+    router_config = config.router_config_data
+    engine_config = config.engine_config_data
+    model = engine_config.model
+    # skip quietly if no real model checkpoint is available
+    if model is None or not os.path.exists(model):
+        return
+    driver = Driver(router_config=router_config, engine_config=engine_config)
+    prompt = "Introduce some landmarks in Beijing"
+
+    request_id = str(uuid.uuid4().hex)
+    sampling_params = SamplingParams()
+
+    async def get_result(request_id, prompt, sampling_params):
+        return await driver.async_generate(request_id, prompt, sampling_params)
+
+    # exercise both the async and the sync generation paths
+    for test_async in [True, False]:
+        print("test_async: ", test_async)
+        if test_async:
+            result = asyncio.run(get_result(request_id, prompt, sampling_params))
+        else:
+            result = driver.generate(request_id, prompt, sampling_params)
+        print("result: ", result)
+
+
+def check_ray_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_ray_dist(PATH)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_ray_dist():
+    spawn(check_ray_dist, 1)
+
+
+if __name__ == "__main__":
+    test_ray_dist()
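The test returns early unless `MODEL_PATH` in `config.yaml` points at a real checkpoint, and it is marked `dist`, so it can be run with `pytest -m dist tests/test_infer/test_dynamic_batching/test_ray_dist.py`. A sketch of what the entry point does under the hood:

```python
# Sketch: test_ray_dist() delegates to colossalai.testing.spawn, which starts
# one process per rank on a free port and calls the target function with
# (rank, world_size, port); check_ray_dist then launches colossalai and
# drives the Ray workers defined in ray_dist_init.py.
from colossalai.testing import spawn

spawn(check_ray_dist, 1)  # check_ray_dist as defined in the test file above
```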