From e0757c31fb4491fef908d897846d47b030fd56f1 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Wed, 11 Oct 2023 17:52:52 +0800 Subject: [PATCH 01/32] [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 --- .../inference/dynamic_batching/__init__.py | 0 .../inference/dynamic_batching/infer_batch.py | 346 ++++++++++++++++++ .../inference/dynamic_batching/io_struct.py | 149 ++++++++ .../inference/dynamic_batching/req_queue.py | 71 ++++ .../dynamic_batching/sampling_params.py | 82 +++++ .../inference/dynamic_batching/stats.py | 43 +++ colossalai/inference/manager.py | 243 ++++++++++++ .../inference/tensor_parallel/engine.py | 115 +++++- .../tensor_parallel/modeling/_utils.py | 2 +- .../tensor_parallel/modeling/llama.py | 45 +-- colossalai/kernel/triton/__init__.py | 1 - .../kernel/triton/copy_kv_cache_dest.py | 1 - tests/kit/model_zoo/transformers/llama.py | 6 +- .../test_dynamic_batching_manager.py | 94 +++++ .../test_dynamic_batching/test_forward.py | 70 ++++ tests/test_infer/test_llama_infer.py | 1 - 16 files changed, 1221 insertions(+), 48 deletions(-) create mode 100644 colossalai/inference/dynamic_batching/__init__.py create mode 100644 colossalai/inference/dynamic_batching/infer_batch.py create mode 100644 colossalai/inference/dynamic_batching/io_struct.py create mode 100644 colossalai/inference/dynamic_batching/req_queue.py create mode 100644 colossalai/inference/dynamic_batching/sampling_params.py create mode 100644 colossalai/inference/dynamic_batching/stats.py create mode 100644 colossalai/inference/manager.py create mode 100644 tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py create mode 100644 tests/test_infer/test_dynamic_batching/test_forward.py diff --git a/colossalai/inference/dynamic_batching/__init__.py b/colossalai/inference/dynamic_batching/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py new file mode 100644 index 000000000000..826272db3e11 --- /dev/null +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -0,0 +1,346 @@ +import collections +from dataclasses import dataclass +from typing import Dict, List , Tuple + +import numpy as np +import torch + +from colossalai.inference.tensor_parallel import MemoryManager + +# make batch infer state an attr of InferBatch + + +class InferSamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + vocab_size: int = -1, + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + if self.top_k == -1: + self.top_k = vocab_size + return + + +@dataclass +class InferBatch: + batch_id: int + requests: List + requests_idx_mapping: Dict[int, int] + + input_ids: torch.Tensor + + all_input_ids: List[List[int]] + input_lengths: List[int] + + out_token_id_counts: List + sampling_param_list: List[InferSamplingParams] + + nopad_total_token_num: int + nopad_max_len_in_batch: int + nopad_b_loc: torch.Tensor + nopad_b_start_loc: torch.Tensor 
+ nopad_b_seq_len: torch.Tensor + cache_manager: MemoryManager + max_total_len: int + + @classmethod + @torch.no_grad() + def init_batch( + cls, + batch_id, + requests, + dtype: torch.dtype, + device: torch.device, + cache_manager: MemoryManager, + vocab_size: int, + max_total_len: int, + ) -> 'InferBatch': + input_lengths = [] + all_input_ids = [] + requests_idx_mapping = {} + + out_token_id_counts = [] + sampling_param_list = [] + + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") + # to avoid memory leak , we pre-allocate 12 more space for each batch. + nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") + for i, r in enumerate(requests): + # request id -> idx in list mapping + requests_idx_mapping[r["request_id"]] = i + + tokenized_input = r["input_id"] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + out_token_id_counts.append(collections.defaultdict(int)) + + # postprocessor + sampling_param = r["sampling_param"] + sampling_param["vocab_size"] = vocab_size + sampling_param_list.append(InferSamplingParams(**sampling_param)) + + nopad_total_token_num += input_length + nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) + + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + + if len(requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + else: + input_ids = all_input_ids[0] + + # Create tensors on device + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + return cls( + batch_id=batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=cache_manager, + max_total_len=max_total_len, + ) + + @torch.no_grad() + def free_self(self) -> None: + """ + Free the memory of the InferBatch itself + """ + remove_index = [] + for idx in range(len(self)): + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + + @torch.no_grad() + def filter(self, request_ids: List[int]) -> 'InferBatch': + """ + Filter finished batch and return a new InferBatch with left ones. 
+ """ + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + requests_idx_mapping = {} + indices = [] + requests = [] + all_input_ids = [] + input_lengths = [] + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + + left_idx = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + left_idx.append(idx) + + left_idx_set = set(left_idx) + remove_index = [] + for idx in range(len(self)): + if idx not in left_idx_set: + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + nopad_max_len_in_batch = 0 + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + + nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] + nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + nopad_total_token_num = torch.sum(nopad_b_seq_len).item() + + nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[ + indices, + (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1), + ] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + requests.append(self.requests[idx]) + all_input_ids.append(self.all_input_ids[idx]) + input_lengths.append(self.input_lengths[idx]) + + input_ids = self.input_ids[indices] + + return InferBatch( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], + sampling_param_list=[self.sampling_param_list[_i] for _i in indices], + cache_manager=self.cache_manager, + max_total_len=self.max_total_len, + ) + + @classmethod + @torch.no_grad() + def merge(cls, batch1, batch2) -> 'InferBatch': + """ + Return megerd new InferBatch + """ + requests = batch1.requests + batch2.requests + requests_idx_mapping = {} + new_batch_size = len(batch1) + len(batch2) + + input_ids = batch1.input_ids.new_empty(new_batch_size) + all_input_ids = [] + input_lengths = [] + out_token_id_counts = [] + sampling_param_list = [] + + cumulative_batch_size = 0 + nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num + nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) + max_total_len = max(batch1.max_total_len, batch2.max_total_len) + nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, 
device="cuda") + nopad_start_loc_len_temp = 0 + batches = [batch1, batch2] + for i, batch in enumerate(batches): + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + cumulative_batch_size + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + input_ids[start_index:end_index] = batch.input_ids + nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len + nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp + nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] + nopad_b_loc[ + start_index:end_index, + nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1, + ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1] + + all_input_ids.extend(batch.all_input_ids) + + input_lengths.extend(batch.input_lengths) + out_token_id_counts.extend(batch.out_token_id_counts) + sampling_param_list.extend(batch.sampling_param_list) + # Update + cumulative_batch_size += len(batch) + + nopad_b_loc[:, nopad_max_len_in_batch - 1] = ( + nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda") + ) + return InferBatch( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=batches[0].cache_manager, + max_total_len=max_total_len, + ) + + def __len__(self): + return len(self.requests) + + def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + top_ks: List[int] = [] + p_token_ids: List[int] = [] + p_token_counts: List[int] = [] + p_seq_len: List[int] = [ + 0, + ] + p_max_len_in_batch: int = 0 + for i, id_to_count in enumerate(self.out_token_id_counts): + sample_param = self.sampling_param_list[i] + presence_penalties.append(sample_param.presence_penalty) + frequency_penalties.append(sample_param.frequency_penalty) + temperatures.append(sample_param.temperature) + top_ps.append(sample_param.top_p) + top_ks.append(sample_param.top_k) + + for token_id, count in id_to_count.items(): + p_token_ids.append(token_id) + p_token_counts.append(count) + p_seq_len.append(len(id_to_count)) + p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) + + presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") + frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") + temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") + top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda") + top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda") + p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda") + p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") + p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") + p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, 
dtype=torch.int32) + return ( + presence_penalties, + frequency_penalties, + temperatures, + top_ps, + top_ks, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + p_max_len_in_batch, + ) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py new file mode 100644 index 000000000000..2b2739f0ae90 --- /dev/null +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -0,0 +1,149 @@ +from typing import Dict, List, Tuple + +from .sampling_params import SamplingParams + + +class Req: + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + self.request_id = request_id + self.prompt_ids = prompt_ids + self.input_len = len(prompt_ids) + self.max_output_len = sample_params.max_new_tokens + self.sample_params = sample_params + self.output_ids = [] + self.output_metadata_list = [] + self.has_generate_finished = False + self.aborted = False + + def to_rpc_obj(self): + return { + "request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict(), + } + + def to_req_detokenization_state(self): + out = ReqDetokenizationState( + self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos + ) + if self.output_metadata_list: + out.gen_metadata.update(self.output_metadata_list[-1]) + return out + + def stop_sequences_matched(self): + # should we add stpp sequences to the sample params? + if self.sample_params.stop_sequences is not None: + for stop_token_ids in self.sample_params.stop_sequences: + stop_len = len(stop_token_ids) + if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + return True + return False + + def __repr__(self): + return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " + + +class ReqDetokenizationState: + def __init__( + self, + request_id: str, + prompt_ids: List[int], + max_output_len: int, + ignore_eos: bool, + ) -> None: + self.request_id = request_id + self.prompt_ids = prompt_ids + self.output_ids = [] + self.output_tokens = [] + self.output_str = "" + self.sub_texts = [] + self.current_sub_text = [] + self.max_output_len = max_output_len + self.ignore_eos = ignore_eos + self.gen_metadata = {} + + +class Batch: + def __init__(self, batch_id, reqs: List[Req]): + self.batch_id = batch_id + self.reqs = reqs + self.id_to_reqs = {req.request_id: req for req in reqs} + + def input_tokens(self): + batch_input_tokens = 0 + for req in self.reqs: + batch_input_tokens += req.input_len + return batch_input_tokens + + def calcu_max_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + req.max_output_len + return tokens + + def calcu_used_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + len(req.output_ids) + return tokens + + def mark_finished_req(self, eos_id): + has_new_finish = False + for req in self.reqs: + if req.stop_sequences_matched(): + req.has_generate_finished = True + has_new_finish = True + if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= req.max_output_len or req.aborted: + req.has_generate_finished = True + has_new_finish = True + return has_new_finish + + def filter_finished(self): + """ + Filter finished requests from the batch, the finished ones will be removed from 'reqs'. 
+ """ + # TODO: the logic of return should be defined here. + unfinished_req = [] + for req in self.reqs: + if not req.has_generate_finished: + unfinished_req.append(req) + self.reqs = unfinished_req + self.id_to_reqs = {req.request_id: req for req in self.reqs} + + def is_clear(self): + return len(self.reqs) == 0 + + def merge(self, mini_batch): + for _req in mini_batch.reqs: + self.reqs.append(_req) + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return + + def __repr__(self): + return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + + def __len__(self): + return len(self.reqs) + + +class BatchTokenIdOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, int, Dict, bool, bool] + ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + + +class BatchStrOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, str, Dict, bool, bool] + ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + + +class AbortReq: + def __init__(self, req_id): + self.req_id = req_id diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py new file mode 100644 index 000000000000..d9e9b6269cc4 --- /dev/null +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -0,0 +1,71 @@ +import uuid +from typing import List + +import numpy as np + +from .io_struct import Batch, Req + + +class ReqQueue: + def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None: + self.max_total_tokens = max_total_tokens + assert batch_max_tokens is not None + self.batch_max_tokens = batch_max_tokens + self.running_max_req_size = running_max_req_size + self.waiting_req_list: List[Req] = waiting_req_list + + def append(self, req): + self.waiting_req_list.append(req) + return + + def _init_cache_list(self, current_batch: Batch): + if current_batch is not None: + self.cache_len_list = [ + (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) + for req in current_batch.reqs + ] + else: + self.cache_len_list = [] + + # @calculate_time(show=True, min_cost_ms=0.1) + def _can_add_new_req(self, req): + self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis + self.cache_len_list.sort(key=lambda x: -x[1]) + + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + # assert left_out_len_array.min() >= 0 + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + # NOTE: change here < to <= + return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size + + def generate_new_batch(self, current_batch: Batch = None): + if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: + return None + self._init_cache_list(current_batch) + can_run_list = [] + new_batch_total_tokens = 0 + aborted_count = 0 + for req in self.waiting_req_list: + flag = self._can_add_new_req(req) + if req.aborted: + aborted_count += 1 + continue + if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: + can_run_list.append(req) + new_batch_total_tokens += req.input_len + else: + break + + if len(can_run_list) != 0: + new_batch = Batch(uuid.uuid4().hex, can_run_list) + self.waiting_req_list = 
self.waiting_req_list[len(can_run_list) + aborted_count :] + return new_batch + else: + return None + + def __len__(self): + return self.waiting_req_list.__len__() diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py new file mode 100644 index 000000000000..9a0ace4111dd --- /dev/null +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -0,0 +1,82 @@ +"""Sampling parameters for text generation.""" +from typing import List, Optional, Union + +_SAMPLING_EPS = 1e-5 + + +class SamplingParams: + + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, # -1 is for all + ignore_eos: bool = False, + max_new_tokens: int = 16, + stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.ignore_eos = ignore_eos + self.max_new_tokens = max_new_tokens + self.stop_sequences = stop_sequences + if self.do_sample == False: + self.temperature = 1.0 + self.top_p = 1.0 + self.top_k = 1 + if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search + self.temperature = 1.0 + self.top_k = 1 + return + + def verify(self): + if self.presence_penalty < 0.0: + raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") + if self.frequency_penalty < 0.0: + raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}") + if self.temperature <= 0.0: + raise ValueError(f"temperature must > 0.0, got {self.temperature}") + if self.top_p <= 0.0 or self.top_p > 1.0: + raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.") + if self.max_new_tokens < 1: + raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.") + return + + def stop_sentences_to_token_ids(self, tokenizer): + if self.stop_sequences is None: + self.stop_sequences = [] + else: + if isinstance(self.stop_sequences, str): + self.stop_sequences = [self.stop_sequences] + new_stop_sequences = [] + for stop_str in self.stop_sequences: + stop_str_ids = tokenizer.encode(stop_str) + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + stop_str_ids = stop_str_ids[1:] + if len(stop_str_ids) > 0: + new_stop_sequences.append(stop_str_ids) + self.stop_sequences = new_stop_sequences + return + + def to_dict(self): + ret = {} + ret["do_sample"] = self.do_sample + ret["presence_penalty"] = self.presence_penalty + ret["frequency_penalty"] = self.frequency_penalty + ret["temperature"] = self.temperature + ret["top_p"] = self.top_p + ret["top_k"] = self.top_k + # if self.ignore_eos is not None: + # ret["ignore_eos"] = self.ignore_eos + # if self.max_tokens is not None: + # ret["max_tokens"] = self.max_tokens + return ret diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py new file mode 100644 index 000000000000..6d34183f47c4 --- /dev/null +++ b/colossalai/inference/dynamic_batching/stats.py @@ -0,0 +1,43 @@ +import time + + +class Stats: + def __init__(self, log_status, log_stats_interval) -> None: + 
self.log_stats = log_status + self.log_stats_interval = log_stats_interval + self.last_log_time = time.time() + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + return + + def count_prompt_tokens(self, run_batch): + if self.log_stats: + tokens = run_batch.input_tokens() + self.prompt_tokens += tokens + self.all_tokens += tokens + return + + def count_output_tokens(self, run_batch): + if self.log_stats: + tokens = len(run_batch.reqs) + self.output_tokens += tokens + self.all_tokens += tokens + return + + def print_stats(self): + if not self.log_stats: + return + + now = time.time() + if now - self.last_log_time > self.log_stats_interval: + print( + f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s" + ) + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + self.last_log_time = now + return diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py new file mode 100644 index 000000000000..72f77406789f --- /dev/null +++ b/colossalai/inference/manager.py @@ -0,0 +1,243 @@ +import time +from typing import List + +from .dynamic_batching.infer_batch import InferBatch +from .dynamic_batching.io_struct import Batch, Req +from .dynamic_batching.req_queue import ReqQueue +from .dynamic_batching.sampling_params import SamplingParams +from .dynamic_batching.stats import Stats +from .tensor_parallel import TPInferEngine + + +class DynamicBatchManager: + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + eos_id, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + self.engine = tp_engine + self.max_total_token_num = max_total_token_num + running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) + # all the inputs should be put into req_queue: waiting req list + + self.running_batch: Batch = running_batch + self.eos_id = eos_id + self.has_wait_tokens = 0 + self.max_wait_tokens = 10 + + self.stats_tool = Stats(log_stats, log_stats_interval) + self.mem_usage_interval = log_stats_interval * 2 + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + """ + Add new request to req queue, during initialization all requests are held in waiting list. 
+ """ + req = Req(request_id, prompt_ids, sampling_params) + self.req_queue.append(req) + return + + def abort(self, request_id): + if self.running_batch is not None: + for req in self.running_batch.reqs: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + for req in self.req_queue.waiting_req_list: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + return + + def loop_for_fwd(self): + """ + The main loop for a dynamic batching process. + """ + counter_count = 0 + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() + counter_count += 1 + if self.running_batch is not None: + if counter_count % self.mem_usage_interval == 0: + print( + "current batch size:", + len(self.running_batch.reqs), + "token used ratio:", + self.running_batch.calcu_used_tokens() / self.max_total_token_num, + ) + self.stats_tool.print_stats() + + if self.running_batch is None: + time.sleep(0.1) # 10ms + + def _step(self): + """ + Logic for handling requests + """ + + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + return + + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + return + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + else: + self.stats_tool.count_output_tokens(self.running_batch) + self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + return + + def _init_batch(self, batch: Batch, dtype="fp16"): + reqs = [r.to_rpc_obj() for r in batch.reqs] + batch_id = batch.batch_id + + import torch + + if dtype == "fp16": + dtype = torch.float16 + else: + assert False, "error dtype" + + batch_data = InferBatch.init_batch( + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.engine.cache_manager, + self.engine.model.config.vocab_size, + self.engine.max_input_len + self.engine.max_output_len, + ) + self.engine.cache[batch_id] = batch_data + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. 
+ """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + self._handle_finish_req(batch, has_new_finished_req) + + def _filter_batch(self, batch: Batch): + batch_id = batch.batch_id + req_id_list = [r.request_id for r in batch.reqs] + batch = self.engine.cache.pop(batch_id) + filter_batch = batch.filter(req_id_list) + del batch + self.engine.cache[batch_id] = filter_batch + + def _merge_batch(self, batch1, batch2): + """ + Merge new mini batch into running batch. + """ + batch1 = self.engine.cache.pop(batch1.batch_id) + batch2 = self.engine.cache.pop(batch2.batch_id) + + m_batch = InferBatch.merge(batch1, batch2) + self.engine.cache[batch1.batch_id] = m_batch + del batch1 + del batch2 + + def _remove_batch(self, batch): + """ + Remove finished batch. + """ + batch = self.engine.cache.pop(batch.batch_id) + batch.free_self() + del batch + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + + def _filter_runing_batch(self): + if self.running_batch is not None and self.running_batch.is_clear(): + self.running_batch = None + + def _add_token_id_to_req(self, batch: Batch, req_ans): + for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): + req = batch.id_to_reqs[req_id] + req.output_ids.append(new_token_id) + req.output_metadata_list.append(new_gen_metadata) + return + + def clean_up(self): + # this logic should be implemented in the future. 
+ pass + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + # try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + # except Exception: + # batch_manager.clean_up() + # raise + + batch_manager.loop_for_fwd() + return diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index d5ef37fee420..f7fb7a825694 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,7 +1,6 @@ from typing import Any, Callable, List, Optional, Union import torch -import torch.distributed as dist import torch.nn as nn from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig @@ -14,6 +13,8 @@ from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager +# from dynamic_batching.infer_batch import InferBatch + DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 _supported_models = [ @@ -90,6 +91,8 @@ def __init__( self.shard_config = shard_config self.model = None + self.cache = {} + # optimize the original model by sharding with ShardFormer self._optimize_model(model=model.to(device)) @@ -116,13 +119,15 @@ def _init_manager(self) -> None: def _post_init_gptq_buffer(self, model: nn.Module) -> None: from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear + HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False for name, submodule in model.named_modules(): @@ -130,8 +135,9 @@ def _post_init_gptq_buffer(self, model: nn.Module) -> None: self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) if self.use_act_order: - self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, - submodule.outfeatures) + self.max_inner_outer_dim = max( + self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures + ) self.bits = submodule.bits if not (HAS_GPTQ_CUDA and self.bits == 4): return @@ -141,15 +147,16 @@ def _post_init_gptq_buffer(self, model: nn.Module) -> None: max_input_len = self.max_input_len # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
- self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), - dtype=torch.float16, - device=torch.cuda.current_device()) - self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), - dtype=torch.float16, - device=torch.cuda.current_device()) - - gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, - self.gptq_temp_dq_buffer) + self.gptq_temp_state_buffer = torch.zeros( + (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() + ) + self.gptq_temp_dq_buffer = torch.zeros( + (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device() + ) + + gptq_cuda.prepare_buffers( + torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer + ) # Using the default from exllama repo here. matmul_recons_thd = 8 matmul_fused_remap = False @@ -270,7 +277,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState: attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 @@ -304,6 +310,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state @torch.no_grad() @@ -367,6 +374,86 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) infer_state.seq_len += 1 + @torch.no_grad() + def forward(self, batch_id, is_prefill): + """ + Forward is used in Dynamic Batching Manager + """ + batch = self.cache.pop(batch_id) + + if is_prefill: + input_ = torch.tensor(batch.all_input_ids).cuda() + else: + input_ = batch.input_ids.reshape(len(batch), 1) + + batch_args = { + "batch_size": len(batch), + "max_len_in_batch": batch.nopad_max_len_in_batch, + "block_loc": batch.nopad_b_loc, + "start_loc": batch.nopad_b_start_loc, + "seq_len": batch.nopad_b_seq_len, + "cache_manager": batch.cache_manager, + "is_context_stage": is_prefill, + } + + infer_state = BatchInferState(**batch_args) + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + setattr(model, "infer_state", infer_state) + + output = self.model.forward(input_ids=input_) + logits = output.logits + # bsz, seq_len, vocab_size + prob_out = torch.softmax( + logits[ + :, + -1, + ], + dim=-1, + ).squeeze(1) + # prob_out: bsz, vocab_size + predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) + prob_out = torch.log(prob_out).detach().cpu().numpy() + predict_ids = predict_ids.detach().cpu().numpy() + # [ batch_size, 1 ] + + output_dict = {} + new_input_ids = [] + for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( + zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) + ): + next_token_id = int(next_token_id) + next_token_logprob = next_token_logprob[next_token_id] + # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") + all_input_ids.append(next_token_id) + # all_input_ids_tensor = None + new_input_ids.append(next_token_id) + batch.all_input_ids[i] = all_input_ids + 
batch.input_lengths[i] += 1 + batch.out_token_id_counts[i][next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + output_dict[r["request_id"]] = (int(next_token_id), metadata) + + batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() + batch.nopad_total_token_num += len(batch) + batch.nopad_max_len_in_batch += 1 + self.cache[batch.batch_id] = batch + return output_dict + + @torch.no_grad() + def _prefill_batch(self, batch_id): + return self.forward(batch_id, is_prefill=True) + + @torch.no_grad() + def _decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + # might want to create a sequence pool # add a single request/sequence/input text at a time and record its length # In other words, store the actual length of input tokens representing a single input text diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py index e476c3132538..068b64b4f829 100644 --- a/colossalai/inference/tensor_parallel/modeling/_utils.py +++ b/colossalai/inference/tensor_parallel/modeling/_utils.py @@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False): base = float(base) # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None)) + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) if ntk_alpha is not None: ntk_alpha = float(ntk_alpha) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index a7661cee1128..958868a0974e 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -62,12 +62,11 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - batch_size = input_ids.shape[0] # input_ids.shape[0] - infer_state = self.infer_state return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -78,15 +77,12 @@ def llama_model_forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + # NOT READY FOR PRIME TIME + # dummy but work, revise it + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -106,23 +102,23 @@ def llama_model_forward( infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = 
infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) + position_ids = position_ids.repeat(batch_size, 1) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -134,6 +130,7 @@ def llama_model_forward( infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 ) + else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) @@ -145,7 +142,7 @@ def llama_model_forward( # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = self._prepare_decoder_attention_mask( @@ -160,7 +157,6 @@ def llama_model_forward( next_decoder_cache = () if use_cache else None infer_state.decode_layer_id = 0 - for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] if past_key_values is not None else None # NOTE: modify here for passing args to decoder layer @@ -184,7 +180,7 @@ def llama_model_forward( # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 if not return_dict: @@ -211,7 +207,6 @@ def llama_decoder_layer_forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -267,11 +262,8 @@ def llama_flash_attn_kvcache_forward( # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin - # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) 
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) @@ -282,7 +274,6 @@ def llama_flash_attn_kvcache_forward( if infer_state.is_context_stage: # first token generation - # copy key and value calculated in current step to memory manager copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -291,9 +282,7 @@ def llama_flash_attn_kvcache_forward( infer_state.context_mem_index, infer_state.cache_manager, ) - attn_output = torch.empty_like(query_states) - llama_context_attn_fwd( query_states, key_states, @@ -301,7 +290,7 @@ def llama_flash_attn_kvcache_forward( attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: if infer_state.decode_is_contiguous: @@ -338,7 +327,7 @@ def llama_flash_attn_kvcache_forward( infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 9830691581c0..070ebe45f659 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -2,7 +2,6 @@ import triton HAS_TRITON = True - except ImportError: HAS_TRITON = False print("Triton is not installed. Please install Triton to use Triton kernels.") diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 02edcc9a903a..0520bc111384 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -51,7 +51,6 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" num_warps = 2 - _fwd_copy_kv_cache_dest[(seq_len,)]( k_ptr, dest_index_ptr, diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index bc229b17e08c..1d1e154b6e70 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -27,8 +27,10 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- - input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + input_ids = torch.Tensor( + [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + ).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py new file mode 100644 index 000000000000..124f1f478b00 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -0,0 +1,94 @@ +import pytest +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import DynamicBatchManager +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import 
ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 +MAX_OUTPUT_LEN = 16 + + +def run(): + sampling_params = SamplingParams() + + req1 = Req(0, [1], sampling_params) + req2 = Req(1, [2], sampling_params) + req3 = Req(2, [3], sampling_params) + # req 1-3 are initiliazed as token forward requests + req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + + # init model and tp engine + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + dynamic_batch_manager = DynamicBatchManager( + tp_engine=infer_engine, + max_total_token_num=42, + batch_max_tokens=42, + eos_id=0, + log_stats=False, + log_stats_interval=10, + waiting_req_list=waiting_list, + ) + before_add = len(dynamic_batch_manager.req_queue) + + # test add req function + dynamic_batch_manager.add_req(req4.prompt_ids, req4.sample_params, req4.request_id) + assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1 + + # test abort function + dynamic_batch_manager.abort(req4.request_id) + assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True + + # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested + batch = dynamic_batch_manager.req_queue.generate_new_batch() + assert len(batch) == 2 + + dynamic_batch_manager._init_batch(batch) + assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None + + batch.reqs[0].has_generate_finished = True + # filter one finished + batch.filter_finished() + dynamic_batch_manager._filter_batch(batch) + assert len(dynamic_batch_manager.engine.cache) == 1 + + # test merge batch + new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch) + assert len(new_batch) == 1 + dynamic_batch_manager._init_batch(new_batch) + dynamic_batch_manager._merge_batch(batch, new_batch) + + assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2 + + +def check_dynamic_batching_manager(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching_manager(): + spawn(check_dynamic_batching_manager, 1) + + +if __name__ == "__main__": + test_dynamic_batching_manager() diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py new file mode 100644 index 000000000000..63df491e5b52 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -0,0 +1,70 @@ +import pytest +import torch +from packaging import version +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from dataclasses import dataclass +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.tensor_parallel import TPInferEngine 
+from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +MAX_BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 +MAX_OUTPUT_LEN = 16 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + +@dataclass +class args: + max_total_token_num: int + batch_max_tokens: int + eos_id: int + disable_log_stats: bool + log_stats_interval: int + + +def run(): + arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) + sampling_params = SamplingParams() + + req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) + req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) + req3 = Req(2, [0, 0, 10, 10, 10], sampling_params) + req4 = Req(3, [0, 0, 10, 10, 10], sampling_params) + + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + waiting_list.append(req4) + + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) + + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + + +def check_dynamic_forward(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching(): + spawn(check_dynamic_forward, TP_SIZE) + + +if __name__ == "__main__": + test_dynamic_batching() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 13bdf03996b9..b424525a3719 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -38,7 +38,6 @@ def run_llama_test(test_config): enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True ) infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - init_to_get_rotary(model.model, base=10000) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) input_tokens = { From fced14025043e29ce816b315f440601188f7f79f Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 12 Oct 2023 18:48:27 +0800 Subject: [PATCH 02/32] [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 --- .../inference/dynamic_batching/io_struct.py | 8 +- colossalai/inference/manager.py | 120 ++++++++++++++---- colossalai/inference/test_async.py | 33 +++++ .../test_dynamic_batching/test_forward.py | 10 +- 4 files changed, 139 insertions(+), 32 deletions(-) create mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 2b2739f0ae90..44ad2964a39f 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -102,17 +102,21 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self): + def 
filter_finished(self)->List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. unfinished_req = [] + finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) + else: + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 72f77406789f..453570c7ec3e 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,5 +1,6 @@ import time from typing import List +import asyncio from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -8,6 +9,8 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine +from transformers import AutoTokenizer +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -54,6 +57,20 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return + def add_input(self, request_id, sampling_params, input_ids): + """ + Encode and Add new input to req queue. support one sequence input for now. + """ + prompt_ids = self.tokenizer.encode(input_ids) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError( + f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" + ) + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(prompt_ids, sampling_params, request_id) + return + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -66,13 +83,15 @@ def abort(self, request_id): req.aborted = True return - def loop_for_fwd(self): + async def loop_for_fwd(self): """ The main loop for a dynamic batching process. """ counter_count = 0 - while self.running_batch is not None or self.req_queue.waiting_req_list: - self._step() + #self.running_batch is not None or self.req_queue.waiting_req_list + while True: + async for item in self._step(): + yield item counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -87,6 +106,26 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms + def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. 
This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + + def _step(self): """ Logic for handling requests @@ -97,14 +136,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -112,17 +151,18 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -158,7 +198,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -169,7 +210,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -201,11 +242,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs=batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) + def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -218,26 +261,47 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return + async def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield output, req.request_id, req.output_metadata_list + def clean_up(self): # this logic should be implemented in the future. pass + async def generate(self,request_id,prompt_id,sampling_params): + """ + Generate the output of a request. 
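        The request is only tokenized and enqueued here; the decoded output is
        yielded later from loop_for_fwd once the request finishes.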
+ """ + self.add_input(request_id,prompt_id,sampling_params) + def start_dynamic_batching(args, tp_engine, waiting_req_list): - # try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - # except Exception: - # batch_manager.clean_up() - # raise - - batch_manager.loop_for_fwd() - return + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) + prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) + + asyncio.run(prod_task) + + for item in batch_manager.loop_for_fwd(): + print(item) + + return batch_manager diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py new file mode 100644 index 000000000000..08720f36da22 --- /dev/null +++ b/colossalai/inference/test_async.py @@ -0,0 +1,33 @@ +import asyncio + +shared_list = [] + +async def producer(): + for i in range(5): + await asyncio.sleep(1) # 模拟异步获取数据的操作 + shared_list.append(i) + print(f"Produced {i}") + +async def consumer(): + last_index = 0 + while True: + await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 + if last_index < len(shared_list): + item = shared_list[last_index] + print(f"Consumed {item}") + yield item + last_index += 1 + +async def main(): + # 创建生产者和消费者任务 + prod_task = asyncio.create_task(producer()) + + # 等待生产者任务完成 + await prod_task + + async for data in consumer(): + print(data) + # 为了示例的目的,我们只等待一段时间,然后停止消费者 + await asyncio.sleep(5) + +asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 63df491e5b52..ca6401259831 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -42,7 +42,7 @@ def run(): waiting_list.append(req2) waiting_list.append(req3) waiting_list.append(req4) - + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() @@ -50,7 +50,13 @@ def run(): shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + manager._set_tokenizer(tokenizer_name = model.__class__.__name__) + result_generator = manager.loop_for_fwd() + for result in result_generator: + print(result) + + def check_dynamic_forward(rank, world_size, port): From fbf3c09e673794ed18c91d4bab1a7dfea052e95a Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 13 Oct 2023 11:01:18 +0800 Subject: [PATCH 03/32] [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish 
async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> --- .../inference/dynamic_batching/io_struct.py | 15 +- colossalai/inference/manager.py | 139 ++++++++++-------- colossalai/inference/test_async.py | 33 ----- .../test_dynamic_batching/test_forward.py | 29 +++- 4 files changed, 107 insertions(+), 109 deletions(-) delete mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 44ad2964a39f..2028e320baee 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -4,7 +4,7 @@ class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): self.output_metadata_list = [] self.has_generate_finished = False self.aborted = False + self.prompts = prompts def to_rpc_obj(self): return { @@ -36,7 +37,11 @@ def stop_sequences_matched(self): if self.sample_params.stop_sequences is not None: for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) - if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): return True return False @@ -102,7 +107,7 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self) -> List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. 
""" @@ -111,9 +116,9 @@ def filter_finished(self)->List[Req]: finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) else: - finished_req.append(req) + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} return finished_req diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 453570c7ec3e..61276660df07 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,6 +1,7 @@ -import time -from typing import List import asyncio +from typing import List + +from transformers import AutoTokenizer from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -9,9 +10,9 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine -from transformers import AutoTokenizer _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" + class DynamicBatchManager: def __init__( self, @@ -19,6 +20,7 @@ def __init__( max_total_token_num, batch_max_tokens, eos_id, + model, log_stats=True, log_stats_interval=10, running_batch: Batch = None, @@ -30,6 +32,7 @@ def __init__( batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir log_stats : whether to log stats log_stats_interval : log stats interval running_batch : running batch @@ -45,32 +48,32 @@ def __init__( self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - + self.model = model + self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 + self._set_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): """ Add new request to req queue, during initialization all requests are held in waiting list. """ - req = Req(request_id, prompt_ids, sampling_params) + req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, input_ids): + async def add_input(self, request_id, sampling_params, prompts): """ Encode and Add new input to req queue. support one sequence input for now. """ - prompt_ids = self.tokenizer.encode(input_ids) + prompt_ids = self.tokenizer.encode(prompts) prompt_len = len(prompt_ids) if prompt_len > self.engine.max_input_len: - raise ValueError( - f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" - ) + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id) + self.add_req(request_id, prompt_ids, sampling_params, prompts) return - + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -88,10 +91,15 @@ async def loop_for_fwd(self): The main loop for a dynamic batching process. 
""" counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list + # self.running_batch is not None or self.req_queue.waiting_req_list while True: - async for item in self._step(): - yield item + if self.running_batch is not None or self.req_queue.waiting_req_list: + async for result in self._step(): + yield result + else: + # need to wait for new requests + await asyncio.sleep(0.1) + continue counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -103,30 +111,33 @@ async def loop_for_fwd(self): ) self.stats_tool.print_stats() - if self.running_batch is None: - time.sleep(0.1) # 10ms - - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + def _set_tokenizer( + self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True + ): if tokenizer is not None: - self.tokenizer = tokenizer + self.tokenizer = tokenizer else: if "llama" in tokenizer_name.lower() and use_fast == True: print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." 
+ ) + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + except TypeError: use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) - def _step(self): + async def _step(self): """ Logic for handling requests """ @@ -136,14 +147,15 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - yield from self._prefill_batch(self.running_batch) + async for item in self._prefill_batch(self.running_batch): + yield item self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -151,18 +163,20 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - yield from self._prefill_batch(new_mini_batch) + async for item in self._prefill_batch(new_mini_batch): + yield item if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - + else: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + async for item in self._decode_batch(self.running_batch): + yield item self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -187,7 +201,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"): ) self.engine.cache[batch_id] = batch_data - def _prefill_batch(self, batch): + async def _prefill_batch(self, batch): """ For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. 
""" @@ -198,11 +212,11 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) - + async for item in self._handle_finish_req(batch, has_new_finished_req): + yield item # delete finished reqs - def _decode_batch(self, batch: Batch): + async def _decode_batch(self, batch: Batch): """ Decoding process """ @@ -210,7 +224,8 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) + async for item in self._handle_finish_req(batch, has_new_finished_req): + yield item def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -240,15 +255,15 @@ def _remove_batch(self, batch): batch.free_self() del batch - def _handle_finish_req(self, batch: Batch, has_new_finished_req): + async def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs=batch.filter_finished() + finished_reqs = batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) - yield from self._output_process(finished_reqs) - + async for item in self._output_process(finished_reqs): + yield item def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -267,18 +282,24 @@ async def _output_process(self, finished_reqs: List[Req]): """ for req in finished_reqs: output = self.tokenizer.decode(req.output_ids) - yield output, req.request_id, req.output_metadata_list + yield req.prompts + output def clean_up(self): # this logic should be implemented in the future. pass - async def generate(self,request_id,prompt_id,sampling_params): + async def generate(self, request_id, prompt_id, sampling_params): """ Generate the output of a request. 
""" - self.add_input(request_id,prompt_id,sampling_params) - + + await self.add_input(request_id, prompt_id, sampling_params) + + +async def process_data(dbm): + async for data in dbm.loop_for_fwd(): + print(data) + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: @@ -287,21 +308,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list): max_total_token_num=args.max_total_token_num, batch_max_tokens=args.batch_max_tokens, eos_id=args.eos_id, + model=args.model, log_stats=not args.disable_log_stats, log_stats_interval=args.log_stats_interval, waiting_req_list=waiting_req_list, ) except Exception: - batch_manager.clean_up() - raise - - batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) - prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) - - asyncio.run(prod_task) - - for item in batch_manager.loop_for_fwd(): - print(item) + raise RuntimeError("Failed to start dynamic batching") return batch_manager diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py deleted file mode 100644 index 08720f36da22..000000000000 --- a/colossalai/inference/test_async.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio - -shared_list = [] - -async def producer(): - for i in range(5): - await asyncio.sleep(1) # 模拟异步获取数据的操作 - shared_list.append(i) - print(f"Produced {i}") - -async def consumer(): - last_index = 0 - while True: - await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 - if last_index < len(shared_list): - item = shared_list[last_index] - print(f"Consumed {item}") - yield item - last_index += 1 - -async def main(): - # 创建生产者和消费者任务 - prod_task = asyncio.create_task(producer()) - - # 等待生产者任务完成 - await prod_task - - async for data in consumer(): - print(data) - # 为了示例的目的,我们只等待一段时间,然后停止消费者 - await asyncio.sleep(5) - -asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index ca6401259831..1b42e3a1094f 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -1,3 +1,6 @@ +import asyncio +from dataclasses import dataclass + import pytest import torch from packaging import version @@ -5,10 +8,9 @@ from transformers.models.llama.configuration_llama import LlamaConfig import colossalai -from dataclasses import dataclass from colossalai.inference.dynamic_batching.io_struct import Req from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.manager import process_data, start_dynamic_batching from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn @@ -19,17 +21,26 @@ MAX_OUTPUT_LEN = 16 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + @dataclass class args: max_total_token_num: int batch_max_tokens: int eos_id: int + model: str disable_log_stats: bool log_stats_interval: int def run(): - arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) + arg = args( + max_total_token_num=42, + batch_max_tokens=42, + eos_id=0, + model="llama", + disable_log_stats=False, + log_stats_interval=10, + ) sampling_params = SamplingParams() req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) 
@@ -42,7 +53,7 @@ def run(): waiting_list.append(req2) waiting_list.append(req3) waiting_list.append(req4) - + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() @@ -51,12 +62,14 @@ def run(): infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) - manager._set_tokenizer(tokenizer_name = model.__class__.__name__) - result_generator = manager.loop_for_fwd() - for result in result_generator: - print(result) + asyncio.run(test(manager)) +async def test(manager): + asyncio.create_task(process_data(manager)) + await asyncio.sleep(5) + await manager.add_req(4, [0, 0, 10, 10, 10], SamplingParams()) + await asyncio.sleep(5) def check_dynamic_forward(rank, world_size, port): From d509e796cd4181bf8f276bc4c84580e2fc8f4ef9 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 13 Oct 2023 16:04:52 +0800 Subject: [PATCH 04/32] Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. --- .../inference/dynamic_batching/io_struct.py | 15 +- colossalai/inference/manager.py | 139 ++++++++---------- colossalai/inference/test_async.py | 33 +++++ .../test_dynamic_batching/test_forward.py | 29 +--- 4 files changed, 109 insertions(+), 107 deletions(-) create mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 2028e320baee..44ad2964a39f 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -4,7 +4,7 @@ class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -14,7 +14,6 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompt self.output_metadata_list = [] self.has_generate_finished = False self.aborted = False - self.prompts = prompts def to_rpc_obj(self): return { @@ -37,11 +36,7 @@ def stop_sequences_matched(self): if self.sample_params.stop_sequences is not None: for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) - if ( - stop_len > 0 - and len(self.output_ids) >= stop_len - and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) - ): + if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): return True return False @@ -107,7 +102,7 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self) -> List[Req]: + def filter_finished(self)->List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. 
""" @@ -116,9 +111,9 @@ def filter_finished(self) -> List[Req]: finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) else: - finished_req.append(req) + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} return finished_req diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 61276660df07..453570c7ec3e 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,7 +1,6 @@ -import asyncio +import time from typing import List - -from transformers import AutoTokenizer +import asyncio from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -10,9 +9,9 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine +from transformers import AutoTokenizer _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" - class DynamicBatchManager: def __init__( self, @@ -20,7 +19,6 @@ def __init__( max_total_token_num, batch_max_tokens, eos_id, - model, log_stats=True, log_stats_interval=10, running_batch: Batch = None, @@ -32,7 +30,6 @@ def __init__( batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine eos_id : The end token of a seq - model: the model weight dir path, the app will load config, weights and tokenizer from this dir log_stats : whether to log stats log_stats_interval : log stats interval running_batch : running batch @@ -48,32 +45,32 @@ def __init__( self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - self.model = model - + self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 - self._set_tokenizer(tokenizer_name=self.model) - async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): """ Add new request to req queue, during initialization all requests are held in waiting list. """ - req = Req(request_id, prompt_ids, sampling_params, prompts) + req = Req(request_id, prompt_ids, sampling_params) self.req_queue.append(req) return - async def add_input(self, request_id, sampling_params, prompts): + def add_input(self, request_id, sampling_params, input_ids): """ Encode and Add new input to req queue. support one sequence input for now. """ - prompt_ids = self.tokenizer.encode(prompts) + prompt_ids = self.tokenizer.encode(input_ids) prompt_len = len(prompt_ids) if prompt_len > self.engine.max_input_len: - raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") + raise ValueError( + f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" + ) sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(request_id, prompt_ids, sampling_params, prompts) + self.add_req(prompt_ids, sampling_params, request_id) return - + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -91,15 +88,10 @@ async def loop_for_fwd(self): The main loop for a dynamic batching process. 
""" counter_count = 0 - # self.running_batch is not None or self.req_queue.waiting_req_list + #self.running_batch is not None or self.req_queue.waiting_req_list while True: - if self.running_batch is not None or self.req_queue.waiting_req_list: - async for result in self._step(): - yield result - else: - # need to wait for new requests - await asyncio.sleep(0.1) - continue + async for item in self._step(): + yield item counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -111,33 +103,30 @@ async def loop_for_fwd(self): ) self.stats_tool.print_stats() - def _set_tokenizer( - self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True - ): + if self.running_batch is None: + time.sleep(0.1) # 10ms + + def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): if tokenizer is not None: - self.tokenizer = tokenizer + self.tokenizer = tokenizer else: if "llama" in tokenizer_name.lower() and use_fast == True: print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai." - ) - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code - ) - except TypeError: + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code - ) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + - async def _step(self): + def _step(self): """ Logic for handling requests """ @@ -147,15 +136,14 @@ async def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - async for item in self._prefill_batch(self.running_batch): - yield item + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -163,20 +151,18 @@ async def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - async for item in self._prefill_batch(new_mini_batch): - yield item + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - + else: self.stats_tool.count_output_tokens(self.running_batch) - async for item in self._decode_batch(self.running_batch): - yield item 
+ yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -201,7 +187,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"): ) self.engine.cache[batch_id] = batch_data - async def _prefill_batch(self, batch): + def _prefill_batch(self, batch): """ For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. """ @@ -212,11 +198,11 @@ async def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - async for item in self._handle_finish_req(batch, has_new_finished_req): - yield item + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs - async def _decode_batch(self, batch: Batch): + def _decode_batch(self, batch: Batch): """ Decoding process """ @@ -224,8 +210,7 @@ async def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - async for item in self._handle_finish_req(batch, has_new_finished_req): - yield item + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -255,15 +240,15 @@ def _remove_batch(self, batch): batch.free_self() del batch - async def _handle_finish_req(self, batch: Batch, has_new_finished_req): + def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs = batch.filter_finished() + finished_reqs=batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) - async for item in self._output_process(finished_reqs): - yield item + yield from self._output_process(finished_reqs) + def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -282,24 +267,18 @@ async def _output_process(self, finished_reqs: List[Req]): """ for req in finished_reqs: output = self.tokenizer.decode(req.output_ids) - yield req.prompts + output + yield output, req.request_id, req.output_metadata_list def clean_up(self): # this logic should be implemented in the future. pass - async def generate(self, request_id, prompt_id, sampling_params): + async def generate(self,request_id,prompt_id,sampling_params): """ Generate the output of a request. 
""" - - await self.add_input(request_id, prompt_id, sampling_params) - - -async def process_data(dbm): - async for data in dbm.loop_for_fwd(): - print(data) - + self.add_input(request_id,prompt_id,sampling_params) + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: @@ -308,13 +287,21 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list): max_total_token_num=args.max_total_token_num, batch_max_tokens=args.batch_max_tokens, eos_id=args.eos_id, - model=args.model, log_stats=not args.disable_log_stats, log_stats_interval=args.log_stats_interval, waiting_req_list=waiting_req_list, ) except Exception: - raise RuntimeError("Failed to start dynamic batching") + batch_manager.clean_up() + raise + + batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) + prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) + + asyncio.run(prod_task) + + for item in batch_manager.loop_for_fwd(): + print(item) return batch_manager diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py new file mode 100644 index 000000000000..08720f36da22 --- /dev/null +++ b/colossalai/inference/test_async.py @@ -0,0 +1,33 @@ +import asyncio + +shared_list = [] + +async def producer(): + for i in range(5): + await asyncio.sleep(1) # 模拟异步获取数据的操作 + shared_list.append(i) + print(f"Produced {i}") + +async def consumer(): + last_index = 0 + while True: + await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 + if last_index < len(shared_list): + item = shared_list[last_index] + print(f"Consumed {item}") + yield item + last_index += 1 + +async def main(): + # 创建生产者和消费者任务 + prod_task = asyncio.create_task(producer()) + + # 等待生产者任务完成 + await prod_task + + async for data in consumer(): + print(data) + # 为了示例的目的,我们只等待一段时间,然后停止消费者 + await asyncio.sleep(5) + +asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 1b42e3a1094f..ca6401259831 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -1,6 +1,3 @@ -import asyncio -from dataclasses import dataclass - import pytest import torch from packaging import version @@ -8,9 +5,10 @@ from transformers.models.llama.configuration_llama import LlamaConfig import colossalai +from dataclasses import dataclass from colossalai.inference.dynamic_batching.io_struct import Req from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import process_data, start_dynamic_batching +from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn @@ -21,26 +19,17 @@ MAX_OUTPUT_LEN = 16 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - @dataclass class args: max_total_token_num: int batch_max_tokens: int eos_id: int - model: str disable_log_stats: bool log_stats_interval: int def run(): - arg = args( - max_total_token_num=42, - batch_max_tokens=42, - eos_id=0, - model="llama", - disable_log_stats=False, - log_stats_interval=10, - ) + arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) sampling_params = SamplingParams() req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) @@ 
-53,7 +42,7 @@ def run(): waiting_list.append(req2) waiting_list.append(req3) waiting_list.append(req4) - + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() @@ -62,14 +51,12 @@ def run(): infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) - asyncio.run(test(manager)) + manager._set_tokenizer(tokenizer_name = model.__class__.__name__) + result_generator = manager.loop_for_fwd() + for result in result_generator: + print(result) -async def test(manager): - asyncio.create_task(process_data(manager)) - await asyncio.sleep(5) - await manager.add_req(4, [0, 0, 10, 10, 10], SamplingParams()) - await asyncio.sleep(5) def check_dynamic_forward(rank, world_size, port): From ec004fe90cafc89205eee8f849096228c6825c81 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Sat, 14 Oct 2023 12:35:03 +0800 Subject: [PATCH 05/32] Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. --- .../inference/dynamic_batching/io_struct.py | 8 +- colossalai/inference/manager.py | 120 ++++-------------- colossalai/inference/test_async.py | 33 ----- .../test_dynamic_batching/test_forward.py | 10 +- 4 files changed, 32 insertions(+), 139 deletions(-) delete mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 44ad2964a39f..2b2739f0ae90 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -102,21 +102,17 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self): """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. unfinished_req = [] - finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) - else: - finished_req.append(req) + unfinished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} - return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 453570c7ec3e..72f77406789f 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,6 +1,5 @@ import time from typing import List -import asyncio from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -9,8 +8,6 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine -from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -57,20 +54,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, input_ids): - """ - Encode and Add new input to req queue. support one sequence input for now. 
- """ - prompt_ids = self.tokenizer.encode(input_ids) - prompt_len = len(prompt_ids) - if prompt_len > self.engine.max_input_len: - raise ValueError( - f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" - ) - sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id) - return - def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -83,15 +66,13 @@ def abort(self, request_id): req.aborted = True return - async def loop_for_fwd(self): + def loop_for_fwd(self): """ The main loop for a dynamic batching process. """ counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list - while True: - async for item in self._step(): - yield item + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -106,26 +87,6 @@ async def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): - if tokenizer is not None: - self.tokenizer = tokenizer - else: - if "llama" in tokenizer_name.lower() and use_fast == True: - print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: - use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - - def _step(self): """ Logic for handling requests @@ -136,14 +97,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - yield from self._prefill_batch(self.running_batch) + self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -151,18 +112,17 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - yield from self._prefill_batch(new_mini_batch) + self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - else: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -198,8 +158,7 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) - + 
self._handle_finish_req(batch, has_new_finished_req) # delete finished reqs def _decode_batch(self, batch: Batch): @@ -210,7 +169,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) + self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -242,13 +201,11 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs=batch.filter_finished() + batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) - yield from self._output_process(finished_reqs) - def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -261,47 +218,26 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return - async def _output_process(self, finished_reqs: List[Req]): - """ - Process the output of a batch. - """ - for req in finished_reqs: - output = self.tokenizer.decode(req.output_ids) - yield output, req.request_id, req.output_metadata_list - def clean_up(self): # this logic should be implemented in the future. pass - async def generate(self,request_id,prompt_id,sampling_params): - """ - Generate the output of a request. - """ - self.add_input(request_id,prompt_id,sampling_params) - def start_dynamic_batching(args, tp_engine, waiting_req_list): - try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - except Exception: - batch_manager.clean_up() - raise - - batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) - prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) - - asyncio.run(prod_task) - - for item in batch_manager.loop_for_fwd(): - print(item) - - return batch_manager + # try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + # except Exception: + # batch_manager.clean_up() + # raise + + batch_manager.loop_for_fwd() + return diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py deleted file mode 100644 index 08720f36da22..000000000000 --- a/colossalai/inference/test_async.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio - -shared_list = [] - -async def producer(): - for i in range(5): - await asyncio.sleep(1) # 模拟异步获取数据的操作 - shared_list.append(i) - print(f"Produced {i}") - -async def consumer(): - last_index = 0 - while True: - await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 - if last_index < len(shared_list): - item = shared_list[last_index] - print(f"Consumed {item}") - yield item - last_index += 1 - -async def main(): - # 创建生产者和消费者任务 - prod_task = asyncio.create_task(producer()) - - # 等待生产者任务完成 - await prod_task - - async for data in consumer(): - print(data) - # 为了示例的目的,我们只等待一段时间,然后停止消费者 - await 
asyncio.sleep(5) - -asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index ca6401259831..63df491e5b52 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -42,7 +42,7 @@ def run(): waiting_list.append(req2) waiting_list.append(req3) waiting_list.append(req4) - + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() @@ -50,13 +50,7 @@ def run(): shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) - manager._set_tokenizer(tokenizer_name = model.__class__.__name__) - result_generator = manager.loop_for_fwd() - for result in result_generator: - print(result) - - + start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) def check_dynamic_forward(rank, world_size, port): From 78cd937fb33e8ba887e7e19065c50190640351cc Mon Sep 17 00:00:00 2001 From: Cuiqing Li Date: Sat, 14 Oct 2023 13:01:04 +0800 Subject: [PATCH 06/32] Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. --- .../inference/dynamic_batching/io_struct.py | 8 +- colossalai/inference/manager.py | 120 ++++-------------- colossalai/inference/test_async.py | 33 ----- .../test_dynamic_batching/test_forward.py | 10 +- 4 files changed, 32 insertions(+), 139 deletions(-) delete mode 100644 colossalai/inference/test_async.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 44ad2964a39f..2b2739f0ae90 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -102,21 +102,17 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self): """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. 
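        # Partition pass: keep only requests that are still generating and rebuild
        # the request-id -> request mapping from the survivors.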
unfinished_req = [] - finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) - else: - finished_req.append(req) + unfinished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} - return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 453570c7ec3e..72f77406789f 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,6 +1,5 @@ import time from typing import List -import asyncio from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req @@ -9,8 +8,6 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine -from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -57,20 +54,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, input_ids): - """ - Encode and Add new input to req queue. support one sequence input for now. - """ - prompt_ids = self.tokenizer.encode(input_ids) - prompt_len = len(prompt_ids) - if prompt_len > self.engine.max_input_len: - raise ValueError( - f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" - ) - sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id) - return - def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -83,15 +66,13 @@ def abort(self, request_id): req.aborted = True return - async def loop_for_fwd(self): + def loop_for_fwd(self): """ The main loop for a dynamic batching process. """ counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list - while True: - async for item in self._step(): - yield item + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -106,26 +87,6 @@ async def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): - if tokenizer is not None: - self.tokenizer = tokenizer - else: - if "llama" in tokenizer_name.lower() and use_fast == True: - print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. 
This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: - use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - - def _step(self): """ Logic for handling requests @@ -136,14 +97,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - yield from self._prefill_batch(self.running_batch) + self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -151,18 +112,17 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - yield from self._prefill_batch(new_mini_batch) + self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - else: self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -198,8 +158,7 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) - + self._handle_finish_req(batch, has_new_finished_req) # delete finished reqs def _decode_batch(self, batch: Batch): @@ -210,7 +169,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - yield from self._handle_finish_req(batch, has_new_finished_req) + self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -242,13 +201,11 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs=batch.filter_finished() + batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) - yield from self._output_process(finished_reqs) - def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -261,47 +218,26 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return - async def _output_process(self, finished_reqs: List[Req]): - """ - Process the output of a batch. - """ - for req in finished_reqs: - output = self.tokenizer.decode(req.output_ids) - yield output, req.request_id, req.output_metadata_list - def clean_up(self): # this logic should be implemented in the future. pass - async def generate(self,request_id,prompt_id,sampling_params): - """ - Generate the output of a request. 
- """ - self.add_input(request_id,prompt_id,sampling_params) - def start_dynamic_batching(args, tp_engine, waiting_req_list): - try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - except Exception: - batch_manager.clean_up() - raise - - batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__) - prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world")) - - asyncio.run(prod_task) - - for item in batch_manager.loop_for_fwd(): - print(item) - - return batch_manager + # try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + # except Exception: + # batch_manager.clean_up() + # raise + + batch_manager.loop_for_fwd() + return diff --git a/colossalai/inference/test_async.py b/colossalai/inference/test_async.py deleted file mode 100644 index 08720f36da22..000000000000 --- a/colossalai/inference/test_async.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio - -shared_list = [] - -async def producer(): - for i in range(5): - await asyncio.sleep(1) # 模拟异步获取数据的操作 - shared_list.append(i) - print(f"Produced {i}") - -async def consumer(): - last_index = 0 - while True: - await asyncio.sleep(0.5) # 为了不使循环过于紧凑,增加了小的延迟 - if last_index < len(shared_list): - item = shared_list[last_index] - print(f"Consumed {item}") - yield item - last_index += 1 - -async def main(): - # 创建生产者和消费者任务 - prod_task = asyncio.create_task(producer()) - - # 等待生产者任务完成 - await prod_task - - async for data in consumer(): - print(data) - # 为了示例的目的,我们只等待一段时间,然后停止消费者 - await asyncio.sleep(5) - -asyncio.run(main()) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index ca6401259831..63df491e5b52 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -42,7 +42,7 @@ def run(): waiting_list.append(req2) waiting_list.append(req3) waiting_list.append(req4) - + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() @@ -50,13 +50,7 @@ def run(): shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) - manager._set_tokenizer(tokenizer_name = model.__class__.__name__) - result_generator = manager.loop_for_fwd() - for result in result_generator: - print(result) - - + start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) def check_dynamic_forward(rank, world_size, port): From d97290af8ade2dcd433cdd8e757fd494fe0edd7b Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Sat, 14 Oct 2023 13:53:51 +0800 Subject: [PATCH 07/32] Add Ray Distributed Environment Init Scripts --- .../inference/dynamic_batching/io_struct.py | 
11 +- .../dynamic_batching/ray_dist_init.py | 115 +++++++++++++++++ .../dynamic_batching/ray_init_config.py | 53 ++++++++ colossalai/inference/manager.py | 121 ++++++++++++++---- .../test_dynamic_batching/config.yaml | 15 +++ .../test_dynamic_batching/test_ray_dist.py | 30 +++++ 6 files changed, 314 insertions(+), 31 deletions(-) create mode 100644 colossalai/inference/dynamic_batching/ray_dist_init.py create mode 100644 colossalai/inference/dynamic_batching/ray_init_config.py create mode 100644 tests/test_infer/test_dynamic_batching/config.yaml create mode 100644 tests/test_infer/test_dynamic_batching/test_ray_dist.py diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 2b2739f0ae90..63165d0a3e5a 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -4,7 +4,7 @@ class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): self.output_metadata_list = [] self.has_generate_finished = False self.aborted = False + self.prompts = prompts def to_rpc_obj(self): return { @@ -102,17 +103,21 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self): + def filter_finished(self)->List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. unfinished_req = [] + finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) + else: + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py new file mode 100644 index 000000000000..0359d162f138 --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -0,0 +1,115 @@ +import logging +import os + +import ray +import ray.util.collective as collective +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass + +ray_serve_logger = logging.getLogger("ray.serve") + +def log_cuda_info(scope_name: str): + ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") + ray_serve_logger.info( + f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" + ) + if torch.cuda.is_available(): + ray_serve_logger.info( + f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" + ) + else: + ray_serve_logger.info(f" {scope_name}: cuda is not available!") + +@ray.remote(num_gpus=1) +class Worker: + def __init__(self, model_path: str, tensor_parallel_size: int, max_batch_size: int, 
max_input_len: int, max_output_len: int, router_config: RooterArgsClass): + log_cuda_info("Worker.init") + self.tensor_parallel_size = tensor_parallel_size + self.model_path = model_path + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.router_config = router_config + + def setup(self, world_size, rank, port): + + # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully + collective.init_collective_group(world_size, rank, "nccl", "default") + # initialize and set distributed environment + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") + log_cuda_info("Worker.setup") + + # Load model + self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 + ) + + shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, []) + + return True + + def generate(self, request_id, prompt, sampling_params) -> str: + + ray_serve_logger.info(f"text: {prompt}") + + results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + + final_output = None + for request_output in results_generator: + final_output = request_output + + assert final_output is not None + ray_serve_logger.info(f"Generated text: {final_output}") + return final_output + +class Driver: + def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): + log_cuda_info("Driver:init") + model_path = engine_config.model + tensor_parallel_size = engine_config.tensor_parallel_size + + self.num_workers = tensor_parallel_size + self.workers = [] + init_rets = [] + + # Just grab a free port on localhost + # NOTE workers in this communication group listen to the same port + available_port = free_port() + + for i in range(self.num_workers): + worker_name = "worker_idx_{}".format(i) + w = Worker.options(name=worker_name).remote( + model_path, self.num_workers, engine_config.max_batch_size, engine_config.max_input_len, engine_config.max_output_len, router_config + ) + self.workers.append(w) + init_rets.append(w.setup.remote(self.num_workers, i, available_port)) + _options = { + "group_name": "default_driver", + "world_size": self.num_workers, + "ranks": [i for i in range(self.num_workers)], + "backend": "nccl", + } + collective.create_collective_group(self.workers, **_options) + _ = ray.get(init_rets) + + # set batch wait delay in seconds and maximum number of sequences in a batch + def generate(self, request_id, prompt, sampling_params): + results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) + text_res = results[0] # get any one of the copies + return text_res \ No newline at end of file diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py new file mode 100644 index 000000000000..0e89d759e987 --- 
/dev/null +++ b/colossalai/inference/dynamic_batching/ray_init_config.py @@ -0,0 +1,53 @@ +import logging + +import yaml +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +class EngineArgsClass(BaseModel): + """Config for Engine""" + model: str + tensor_parallel_size: int = 2 + max_batch_size: int = 4 + max_input_len: int = 128 + max_output_len: int = 32 + +class RooterArgsClass(BaseModel): + """Config for Rooter""" + max_total_token_num: int = 42 + batch_max_tokens: int = 42 + eos_id: int = 0 + disable_log_stats: bool = False + log_stats_interval: int = 10 + model: str + +class RayInitConfig(BaseModel): + """All-together configs without app router config""" + + engine_config_data: EngineArgsClass + router_config_data: RooterArgsClass + + @classmethod + def from_yaml_path(cls, path: str): + try: + with open(path, "r") as yaml_file: + try: + config = yaml.safe_load(yaml_file) + # serve deployment config + engine_config = config.get("engine_config", {}) + router_config = config.get("router_config", {}) + + return cls( + engine_config_data=engine_config, + router_config_data=router_config, + ) + except yaml.YAMLError as e: + logger.error(f"An Error occurred when parsing yaml: {e}") + raise + except FileNotFoundError: + logger.error(f"The file '{path}' does not exist!") + raise + except OSError as e: + logger.error(f"An Error occurred: {e}") + raise diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 72f77406789f..29af3ae1f934 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -8,6 +8,8 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine +from transformers import AutoTokenizer +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -16,6 +18,7 @@ def __init__( max_total_token_num, batch_max_tokens, eos_id, + model, log_stats=True, log_stats_interval=10, running_batch: Batch = None, @@ -27,6 +30,7 @@ def __init__( batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir log_stats : whether to log stats log_stats_interval : log stats interval running_batch : running batch @@ -42,18 +46,35 @@ def __init__( self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 + self.model = model self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 + self._set_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): """ Add new request to req queue, during initialization all requests are held in waiting list. """ - req = Req(request_id, prompt_ids, sampling_params) + req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) + print("len(self.req_queue): ", len(self.req_queue)) return + def add_input(self, request_id, sampling_params, prompts): + """ + Encode and Add new input to req queue. support one sequence input for now. 
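# A minimal sketch of how the pydantic config classes added in ray_init_config.py
# above are meant to be consumed, assuming a local "config.yaml" with
# engine_config/router_config sections (see the example config later in this series):

from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig

config = RayInitConfig.from_yaml_path("config.yaml")
engine_args = config.engine_config_data    # EngineArgsClass: model, tensor_parallel_size, ...
router_args = config.router_config_data    # RooterArgsClass: max_total_token_num, eos_id, ...
print(engine_args.model, engine_args.tensor_parallel_size)
print(router_args.batch_max_tokens, router_args.eos_id)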
+ """ + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError( + f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" + ) + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(prompt_ids, sampling_params, request_id, prompts) + return + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -71,8 +92,14 @@ def loop_for_fwd(self): The main loop for a dynamic batching process. """ counter_count = 0 + #self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or self.req_queue.waiting_req_list: - self._step() + if self.running_batch is not None : + print("len(self.running_batch): ", len(self.running_batch)) + else: + print("len(self.running_batch): ", 0) + print("len(self.req_queue.waiting_req_list): ", len(self.req_queue.waiting_req_list)) + yield from self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -87,6 +114,26 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms + def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + + def _step(self): """ Logic for handling requests @@ -97,14 +144,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -112,17 +159,18 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -158,7 +206,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = 
batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -169,7 +218,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -201,11 +250,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs=batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) + def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -218,26 +269,40 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield req.prompts + output + def clean_up(self): # this logic should be implemented in the future. pass + def generate(self,prompts,sampling_params,request_id): + """ + Generate the output of a request. + """ + self.add_input(request_id,sampling_params,prompts) + return self.loop_for_fwd() def start_dynamic_batching(args, tp_engine, waiting_req_list): - # try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - # except Exception: - # batch_manager.clean_up() - # raise - - batch_manager.loop_for_fwd() - return + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + return batch_manager diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml new file mode 100644 index 000000000000..0129f036a00f --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -0,0 +1,15 @@ +engine_config: + model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 + tensor_parallel_size: 2 + max_batch_size: 4 + max_input_len: 128 + max_output_len: 32 +# config for app router deployment +# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig? 
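# The generator chain introduced above (generate -> loop_for_fwd -> _step ->
# _prefill_batch/_decode_batch -> _handle_finish_req -> _output_process) relies on
# `yield from` to stream each finished request's text up to the caller as soon as
# it is decoded. A self-contained toy of that composition, with hypothetical names
# standing in for the real methods:

def handle_finished(finished_ids):
    # innermost generator: one output per finished request
    for req_id in finished_ids:
        yield f"request {req_id} finished"

def step(pending):
    # one scheduling step: pretend the oldest request finishes
    finished = [pending.pop(0)] if pending else []
    yield from handle_finished(finished)

def loop_for_fwd(pending):
    # outer loop: keep stepping until no work remains
    while pending:
        yield from step(pending)

for text in loop_for_fwd(["a", "b", "c"]):
    print(text)   # prints "request a finished", "request b finished", ...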
+router_config: + max_total_token_num: 42 + batch_max_tokens: 42 + eos_id: 0 + disable_log_stats: False + log_stats_interval: 10 + model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py new file mode 100644 index 000000000000..d889db44b277 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -0,0 +1,30 @@ +import os +from typing import Dict +import uuid +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams + +def test_ray_dist(path: str): + print(f"Using yaml file {path}") + if not os.path.exists(path): + raise FileNotFoundError(f"Invalid yaml file path {path}") + config = RayInitConfig.from_yaml_path(path) + router_config = config.router_config_data + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + raise ValueError("Model path not provided or invalid path!") + + driver = Driver(router_config=router_config, engine_config=engine_config) + prompt = 'Introduce some landmarks in Beijing' + + request_id = str(uuid.uuid4().hex) + + sampling_params = SamplingParams() + + print("result: ", prompt + driver.generate(request_id, prompt, sampling_params)) + +if __name__ == "__main__": + path = "config.yaml" + test_ray_dist(path) From f589e97c94c01dd88054a4df007ed19365d9e77c Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Sat, 14 Oct 2023 19:23:39 +0800 Subject: [PATCH 08/32] support DynamicBatchManager base function --- .../dynamic_batching/ray_dist_init.py | 33 +++++++++++++++++-- colossalai/inference/manager.py | 6 ---- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 0359d162f138..9701ca2cde5a 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -13,6 +13,8 @@ from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from typing import List ray_serve_logger = logging.getLogger("ray.serve") @@ -64,7 +66,7 @@ def setup(self, world_size, rank, port): return True - def generate(self, request_id, prompt, sampling_params) -> str: + def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str: ray_serve_logger.info(f"text: {prompt}") @@ -77,6 +79,19 @@ def generate(self, request_id, prompt, sampling_params) -> str: assert final_output is not None ray_serve_logger.info(f"Generated text: {final_output}") return final_output + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) + + def abort(self,request_id: str): + self.start_dynamic_batching.abort(request_id) + + def step(self): + self.start_dynamic_batching._step() + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + self.start_dynamic_batching.add_req(prompt_ids, 
sampling_params, request_id, prompt) + class Driver: def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): @@ -109,7 +124,19 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas _ = ray.get(init_rets) # set batch wait delay in seconds and maximum number of sequences in a batch - def generate(self, request_id, prompt, sampling_params): + def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) text_res = results[0] # get any one of the copies - return text_res \ No newline at end of file + return text_res + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + + def abort(self,request_id: str): + ray.get([w.abort.remote(request_id) for w in self.workers]) + + def step(self): + ray.get([w._step.remote() for w in self.workers]) + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) \ No newline at end of file diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index b5ee1d027e37..6678ecae0816 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -56,7 +56,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques """ req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) - print("len(self.req_queue): ", len(self.req_queue)) return def add_input(self, request_id, sampling_params, prompts): @@ -92,11 +91,6 @@ def loop_for_fwd(self): counter_count = 0 #self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or self.req_queue.waiting_req_list: - if self.running_batch is not None : - print("len(self.running_batch): ", len(self.running_batch)) - else: - print("len(self.running_batch): ", 0) - print("len(self.req_queue.waiting_req_list): ", len(self.req_queue.waiting_req_list)) yield from self._step() counter_count += 1 if self.running_batch is not None: From c07005074af453f0c59d34dfdb4306fe7fbd29ff Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 11:55:22 +0800 Subject: [PATCH 09/32] revert _set_tokenizer version --- colossalai/inference/manager.py | 42 ++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 6678ecae0816..06eae3ec0ce3 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -8,6 +8,8 @@ from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine +from transformers import AutoTokenizer +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" class DynamicBatchManager: def __init__( @@ -106,6 +108,26 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms + def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. 
To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + + def _step(self): """ Logic for handling requests @@ -116,14 +138,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -131,17 +153,18 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -177,7 +200,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -188,7 +212,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -220,11 +244,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs=batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) + def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): From 5deb95ced8c76fc1b49ef6b24bd0365d1d59459f Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 14:03:06 +0800 Subject: [PATCH 10/32] add driver async generate --- .../dynamic_batching/ray_dist_init.py | 14 ++++++++-- colossalai/inference/manager.py | 26 ++----------------- .../test_dynamic_batching/test_ray_dist.py | 5 +++- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py 
b/colossalai/inference/dynamic_batching/ray_dist_init.py index 9701ca2cde5a..63cf8f33c7a8 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -4,7 +4,7 @@ import ray import ray.util.collective as collective import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine @@ -14,7 +14,9 @@ from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer from typing import List +import asyncio ray_serve_logger = logging.getLogger("ray.serve") @@ -51,7 +53,7 @@ def setup(self, world_size, rank, port): log_cuda_info("Worker.setup") # Load model - self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + self.tokenizer = get_tokenizer(tokenizer_name = self.model_path) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = AutoModelForCausalLM.from_pretrained( @@ -129,6 +131,14 @@ def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams text_res = results[0] # get any one of the copies return text_res + async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + all_outputs = [] + for worker in self.workers: + all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) + all_outputs = await asyncio.gather(*all_outputs) + text_res = all_outputs[0]# get any one of the copies + return text_res + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 06eae3ec0ce3..26d93eb1f14a 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -7,9 +7,7 @@ from .dynamic_batching.sampling_params import SamplingParams from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine - -from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" +from .dynamic_batching.get_tokenizer import get_tokenizer class DynamicBatchManager: def __init__( @@ -50,7 +48,7 @@ def __init__( self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 - self._set_tokenizer(tokenizer_name=self.model) + self.tokenizer = get_tokenizer(tokenizer_name=self.model) def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): """ @@ -108,26 +106,6 @@ def loop_for_fwd(self): if self.running_batch is None: time.sleep(0.1) # 10ms - def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): - if tokenizer is not None: - self.tokenizer = tokenizer - else: - if "llama" in tokenizer_name.lower() and use_fast == True: - print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. 
This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: - use_fast = False - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - - def _step(self): """ Logic for handling requests diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index d889db44b277..c943c74eb456 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -23,7 +23,10 @@ def test_ray_dist(path: str): sampling_params = SamplingParams() - print("result: ", prompt + driver.generate(request_id, prompt, sampling_params)) + result_generator = driver.generate(request_id, prompt, sampling_params) + + for result in result_generator: + print("result: ", result) if __name__ == "__main__": path = "config.yaml" From 306ef77a0c09fd3224ab1dc400ca313ef3db49ef Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 14:38:31 +0800 Subject: [PATCH 11/32] add async test --- .../inference/dynamic_batching/io_struct.py | 8 +++++-- .../test_dynamic_batching/test_ray_dist.py | 21 +++++++++++++------ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index fe5f25e2ea11..63165d0a3e5a 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -103,17 +103,21 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self): + def filter_finished(self)->List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. 
unfinished_req = [] + finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) + else: + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index c943c74eb456..a7bc7b2df246 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -1,9 +1,9 @@ import os -from typing import Dict import uuid from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from colossalai.inference.dynamic_batching.ray_dist_init import Driver from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +import asyncio def test_ray_dist(path: str): print(f"Using yaml file {path}") @@ -23,11 +23,20 @@ def test_ray_dist(path: str): sampling_params = SamplingParams() - result_generator = driver.generate(request_id, prompt, sampling_params) - - for result in result_generator: - print("result: ", result) + async def get_result(request_id, prompt, sampling_params): + return await driver.generate(request_id, prompt, sampling_params) + + for test_async in [True, False]: + if test_async: + print("test_async: ", test_async) + result = get_result(request_id, prompt, sampling_params) + print("result: ", result) + else: + print("test_async: ", test_async) + result = driver.generate(request_id, prompt, sampling_params) + print("result: ", result) + if __name__ == "__main__": path = "config.yaml" - test_ray_dist(path) + test_ray_dist(path) \ No newline at end of file From 632f0e1107f4d454e5e7f28be72b14b455acd562 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 14:52:56 +0800 Subject: [PATCH 12/32] fix bugs in test_ray_dist.py --- tests/test_infer/test_dynamic_batching/test_ray_dist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index a7bc7b2df246..9bf5ff68b6ae 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -24,12 +24,12 @@ def test_ray_dist(path: str): sampling_params = SamplingParams() async def get_result(request_id, prompt, sampling_params): - return await driver.generate(request_id, prompt, sampling_params) + return await driver.async_generate(request_id, prompt, sampling_params) for test_async in [True, False]: if test_async: print("test_async: ", test_async) - result = get_result(request_id, prompt, sampling_params) + result = asyncio.run(get_result(request_id, prompt, sampling_params)) print("result: ", result) else: print("test_async: ", test_async) From 0b2fe513f3aad5f5b8092ed3989af994caab774b Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 15:10:36 +0800 Subject: [PATCH 13/32] add get_tokenizer.py --- .../dynamic_batching/get_tokenizer.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 colossalai/inference/dynamic_batching/get_tokenizer.py diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py new file mode 100644 index 000000000000..ea8116ce66f5 --- /dev/null +++ 
b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -0,0 +1,22 @@ +from transformers import AutoTokenizer +_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" + +def get_tokenizer(tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + if tokenizer is not None: + tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai.") + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + except TypeError as e: + use_fast = False + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) + return tokenizer \ No newline at end of file From cd843ac8f2fc7e9cedee641a4f110b29f03e9d84 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 15:35:38 +0800 Subject: [PATCH 14/32] fix code style --- .../dynamic_batching/get_tokenizer.py | 42 +++++++----- .../inference/dynamic_batching/io_struct.py | 12 ++-- .../dynamic_batching/ray_dist_init.py | 64 +++++++++++-------- .../dynamic_batching/ray_init_config.py | 5 ++ colossalai/inference/manager.py | 27 ++++---- .../test_dynamic_batching/config.yaml | 4 +- .../test_dynamic_batching/test_ray_dist.py | 30 +++++---- 7 files changed, 109 insertions(+), 75 deletions(-) diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py index ea8116ce66f5..af1f26848b3a 100644 --- a/colossalai/inference/dynamic_batching/get_tokenizer.py +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -1,22 +1,34 @@ from transformers import AutoTokenizer + _FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" - -def get_tokenizer(tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,): + + +def get_tokenizer( + tokenizer=None, + tokenizer_name: str = "", + trust_remote_code: bool = False, + use_fast: bool = True, +): if tokenizer is not None: - tokenizer = tokenizer + tokenizer = tokenizer else: if "llama" in tokenizer_name.lower() and use_fast == True: print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai.") - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - except TypeError as e: + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." 
+ ) + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + except TypeError: use_fast = False - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code) - return tokenizer \ No newline at end of file + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + return tokenizer diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 63165d0a3e5a..9faaad6f111e 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -37,7 +37,11 @@ def stop_sequences_matched(self): if self.sample_params.stop_sequences is not None: for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) - if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): return True return False @@ -103,7 +107,7 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self)->List[Req]: + def filter_finished(self) -> List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ @@ -112,9 +116,9 @@ def filter_finished(self)->List[Req]: finished_req = [] for req in self.reqs: if not req.has_generate_finished: - unfinished_req.append(req) + unfinished_req.append(req) else: - finished_req.append(req) + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} return finished_req diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 63cf8f33c7a8..a40a00e2666c 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -1,5 +1,7 @@ +import asyncio import logging import os +from typing import List import ray import ray.util.collective as collective @@ -7,19 +9,17 @@ from transformers import AutoModelForCausalLM import colossalai +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import free_port -from colossalai.inference.manager import start_dynamic_batching -from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass -from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer -from typing import List -import asyncio - ray_serve_logger = logging.getLogger("ray.serve") + def log_cuda_info(scope_name: str): ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") ray_serve_logger.info( @@ -32,9 +32,18 @@ def log_cuda_info(scope_name: str): else: 
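# stop_sequences_matched, reformatted above, answers one question: do the most
# recently generated token ids end with any configured stop sequence? The same
# tail check in isolation, with hypothetical names:

from typing import List

def tail_matches_any(output_ids: List[int], stop_sequences: List[List[int]]) -> bool:
    for stop_token_ids in stop_sequences:
        stop_len = len(stop_token_ids)
        # a non-empty stop sequence matches when the last stop_len tokens equal it
        if stop_len > 0 and len(output_ids) >= stop_len and output_ids[-stop_len:] == stop_token_ids:
            return True
    return False

assert tail_matches_any([5, 7, 2, 3], [[2, 3]])
assert not tail_matches_any([5, 7, 2, 3], [[3, 2]])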
ray_serve_logger.info(f" {scope_name}: cuda is not available!") + @ray.remote(num_gpus=1) class Worker: - def __init__(self, model_path: str, tensor_parallel_size: int, max_batch_size: int, max_input_len: int, max_output_len: int, router_config: RooterArgsClass): + def __init__( + self, + model_path: str, + tensor_parallel_size: int, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + router_config: RooterArgsClass, + ): log_cuda_info("Worker.init") self.tensor_parallel_size = tensor_parallel_size self.model_path = model_path @@ -44,7 +53,6 @@ def __init__(self, model_path: str, tensor_parallel_size: int, max_batch_size: i self.router_config = router_config def setup(self, world_size, rank, port): - # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully collective.init_collective_group(world_size, rank, "nccl", "default") # initialize and set distributed environment @@ -53,7 +61,7 @@ def setup(self, world_size, rank, port): log_cuda_info("Worker.setup") # Load model - self.tokenizer = get_tokenizer(tokenizer_name = self.model_path) + self.tokenizer = get_tokenizer(tokenizer_name=self.model_path) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.model = AutoModelForCausalLM.from_pretrained( @@ -69,7 +77,6 @@ def setup(self, world_size, rank, port): return True def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str: - ray_serve_logger.info(f"text: {prompt}") results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) @@ -81,19 +88,19 @@ def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams assert final_output is not None ray_serve_logger.info(f"Generated text: {final_output}") return final_output - + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) - - def abort(self,request_id: str): + + def abort(self, request_id: str): self.start_dynamic_batching.abort(request_id) - + def step(self): self.start_dynamic_batching._step() - + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - + class Driver: def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): @@ -112,7 +119,12 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas for i in range(self.num_workers): worker_name = "worker_idx_{}".format(i) w = Worker.options(name=worker_name).remote( - model_path, self.num_workers, engine_config.max_batch_size, engine_config.max_input_len, engine_config.max_output_len, router_config + model_path, + self.num_workers, + engine_config.max_batch_size, + engine_config.max_input_len, + engine_config.max_output_len, + router_config, ) self.workers.append(w) init_rets.append(w.setup.remote(self.num_workers, i, available_port)) @@ -130,23 +142,23 @@ def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) text_res = results[0] # get any one of the copies return text_res - + async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): all_outputs = [] for worker in self.workers: all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) 
all_outputs = await asyncio.gather(*all_outputs) - text_res = all_outputs[0]# get any one of the copies + text_res = all_outputs[0] # get any one of the copies return text_res - + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) - - def abort(self,request_id: str): + + def abort(self, request_id: str): ray.get([w.abort.remote(request_id) for w in self.workers]) - + def step(self): ray.get([w._step.remote() for w in self.workers]) - + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): - ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) \ No newline at end of file + ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py index 0e89d759e987..471f07330aec 100644 --- a/colossalai/inference/dynamic_batching/ray_init_config.py +++ b/colossalai/inference/dynamic_batching/ray_init_config.py @@ -5,16 +5,20 @@ logger = logging.getLogger(__name__) + class EngineArgsClass(BaseModel): """Config for Engine""" + model: str tensor_parallel_size: int = 2 max_batch_size: int = 4 max_input_len: int = 128 max_output_len: int = 32 + class RooterArgsClass(BaseModel): """Config for Rooter""" + max_total_token_num: int = 42 batch_max_tokens: int = 42 eos_id: int = 0 @@ -22,6 +26,7 @@ class RooterArgsClass(BaseModel): log_stats_interval: int = 10 model: str + class RayInitConfig(BaseModel): """All-together configs without app router config""" diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 26d93eb1f14a..30717a915e3b 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,13 +1,14 @@ import time from typing import List +from .dynamic_batching.get_tokenizer import get_tokenizer from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req from .dynamic_batching.req_queue import ReqQueue from .dynamic_batching.sampling_params import SamplingParams from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine -from .dynamic_batching.get_tokenizer import get_tokenizer + class DynamicBatchManager: def __init__( @@ -45,7 +46,7 @@ def __init__( self.has_wait_tokens = 0 self.max_wait_tokens = 10 self.model = model - + self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 self.tokenizer = get_tokenizer(tokenizer_name=self.model) @@ -65,13 +66,11 @@ def add_input(self, request_id, sampling_params, prompts): prompt_ids = self.tokenizer.encode(prompts) prompt_len = len(prompt_ids) if prompt_len > self.engine.max_input_len: - raise ValueError( - f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}" - ) + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") sampling_params.stop_sentences_to_token_ids(self.tokenizer) self.add_req(prompt_ids, sampling_params, request_id, prompts) return - + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -89,7 +88,7 @@ def loop_for_fwd(self): The main loop for a dynamic batching process. 
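# The loop documented above repeats one scheduling policy: with no running batch,
# build and prefill a new batch from the waiting queue; otherwise keep decoding
# until max_wait_tokens steps have passed, then try to prefill a waiting
# mini-batch and merge it into the running batch. A toy rendition of that
# schedule, with hypothetical names and a step counter instead of real token
# generation:

MAX_BATCH, MAX_WAIT_TOKENS, STEPS_NEEDED = 2, 3, 5
waiting = [f"req{i}" for i in range(4)]
running, steps_left, wait_tokens = [], {}, 0

while running or waiting:
    if not running:                          # prefill a fresh batch
        running = [waiting.pop(0) for _ in range(min(MAX_BATCH, len(waiting)))]
        steps_left.update({r: STEPS_NEEDED for r in running})
        wait_tokens = 0
    elif wait_tokens < MAX_WAIT_TOKENS:      # decode one step for the whole batch
        for r in list(running):
            steps_left[r] -= 1
            if steps_left[r] == 0:
                running.remove(r)
                print(f"{r} finished")
        wait_tokens += 1
    else:                                    # try to prefill and merge a mini-batch
        new = [waiting.pop(0) for _ in range(min(MAX_BATCH, len(waiting)))]
        steps_left.update({r: STEPS_NEEDED for r in new})
        running += new
        wait_tokens = 0                      # give decoding another window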
""" counter_count = 0 - #self.running_batch is not None or self.req_queue.waiting_req_list + # self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or self.req_queue.waiting_req_list: yield from self._step() counter_count += 1 @@ -136,13 +135,13 @@ def _step(self): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 - + else: self.stats_tool.count_output_tokens(self.running_batch) yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return def _init_batch(self, batch: Batch, dtype="fp16"): @@ -179,7 +178,7 @@ def _prefill_batch(self, batch): self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) yield from self._handle_finish_req(batch, has_new_finished_req) - + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -222,14 +221,13 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - finished_reqs=batch.filter_finished() + finished_reqs = batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) yield from self._output_process(finished_reqs) - def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): self.running_batch = None @@ -253,13 +251,14 @@ def clean_up(self): # this logic should be implemented in the future. pass - def generate(self,prompts,sampling_params,request_id): + def generate(self, prompts, sampling_params, request_id): """ Generate the output of a request. """ - self.add_input(request_id,sampling_params,prompts) + self.add_input(request_id, sampling_params, prompts) return self.loop_for_fwd() + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: batch_manager = DynamicBatchManager( diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml index 0129f036a00f..c31ae8c5fadb 100644 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -1,5 +1,5 @@ engine_config: - model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 + model: MODEL_PATH tensor_parallel_size: 2 max_batch_size: 4 max_input_len: 128 @@ -12,4 +12,4 @@ router_config: eos_id: 0 disable_log_stats: False log_stats_interval: 10 - model: /home/lccd/share/model_data/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348 + model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 9bf5ff68b6ae..09f41ba137de 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -1,9 +1,11 @@ +import asyncio import os import uuid -from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig + from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -import asyncio + def test_ray_dist(path: str): print(f"Using yaml file {path}") @@ -15,28 +17,28 @@ def test_ray_dist(path: str): model = engine_config.model if model is None or not 
os.path.exists(model): raise ValueError("Model path not provided or invalid path!") - + driver = Driver(router_config=router_config, engine_config=engine_config) - prompt = 'Introduce some landmarks in Beijing' - + prompt = "Introduce some landmarks in Beijing" + request_id = str(uuid.uuid4().hex) - + sampling_params = SamplingParams() - + async def get_result(request_id, prompt, sampling_params): return await driver.async_generate(request_id, prompt, sampling_params) - + for test_async in [True, False]: - if test_async: + if test_async: print("test_async: ", test_async) - result = asyncio.run(get_result(request_id, prompt, sampling_params)) + result = asyncio.run(get_result(request_id, prompt, sampling_params)) print("result: ", result) else: print("test_async: ", test_async) - result = driver.generate(request_id, prompt, sampling_params) + result = driver.generate(request_id, prompt, sampling_params) print("result: ", result) - - + + if __name__ == "__main__": path = "config.yaml" - test_ray_dist(path) \ No newline at end of file + test_ray_dist(path) From 8c9ad51484064055c7d8262d5455f641860cbd72 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 16:10:42 +0800 Subject: [PATCH 15/32] fix bugs about No module named 'pydantic' in ci test --- requirements/requirements-test.txt | 1 + requirements/requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 467f83610eb0..e22c1d1a5127 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,4 +18,5 @@ SentencePiece ninja flash_attn==2.0.5 datasets +pydantic #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9aa5f2822e40..421784f3de87 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,3 +11,4 @@ ninja torch>=1.12 safetensors einops +pydantic From 8d0cc6b51a8690c1fca0f42b94a5c5dcf46ba70b Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 16:23:03 +0800 Subject: [PATCH 16/32] fix bugs in ci test --- requirements/requirements-test.txt | 1 + requirements/requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index e22c1d1a5127..f54b13c7e43c 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -19,4 +19,5 @@ ninja flash_attn==2.0.5 datasets pydantic +ray #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 421784f3de87..8a4b0f1a0ffd 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -12,3 +12,4 @@ torch>=1.12 safetensors einops pydantic +ray From acdd751a2fd080b6c0d61f538a686513fe7cb818 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 16:50:49 +0800 Subject: [PATCH 17/32] fix bugs in ci test --- .../test_dynamic_batching/test_ray_dist.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 09f41ba137de..76e47c7eabd3 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -5,6 +5,9 @@ from colossalai.inference.dynamic_batching.ray_dist_init import Driver from 
colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +import colossalai +import pytest +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn def test_ray_dist(path: str): @@ -16,8 +19,7 @@ def test_ray_dist(path: str): engine_config = config.engine_config_data model = engine_config.model if model is None or not os.path.exists(model): - raise ValueError("Model path not provided or invalid path!") - + return driver = Driver(router_config=router_config, engine_config=engine_config) prompt = "Introduce some landmarks in Beijing" @@ -38,6 +40,16 @@ async def get_result(request_id, prompt, sampling_params): result = driver.generate(request_id, prompt, sampling_params) print("result: ", result) +def check_dynamic_batching_manager(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + test_ray_dist() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching_manager(): + spawn(check_dynamic_batching_manager, 1) if __name__ == "__main__": path = "config.yaml" From 8a761bdad4f3907f3fe42c680c1edb15c8ad342a Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 17:38:08 +0800 Subject: [PATCH 18/32] fix bugs in ci test --- .../test_dynamic_batching/test_ray_dist.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 76e47c7eabd3..4cf9881f41dc 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -9,8 +9,9 @@ import pytest from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +PATH = "config.yaml" -def test_ray_dist(path: str): +def run_ray_dist(path: str): print(f"Using yaml file {path}") if not os.path.exists(path): raise FileNotFoundError(f"Invalid yaml file path {path}") @@ -40,17 +41,16 @@ async def get_result(request_id, prompt, sampling_params): result = driver.generate(request_id, prompt, sampling_params) print("result: ", result) -def check_dynamic_batching_manager(rank, world_size, port): +def check_ray_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - test_ray_dist() + run_ray_dist(PATH) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_dynamic_batching_manager(): - spawn(check_dynamic_batching_manager, 1) +def test_ray_dist(): + spawn(check_ray_dist, 1) if __name__ == "__main__": - path = "config.yaml" - test_ray_dist(path) + test_ray_dist() From 56f75c4aacfb473a6372af90243b61eef33f8dcc Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 16 Oct 2023 18:01:26 +0800 Subject: [PATCH 19/32] [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. 
* Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test --- .../dynamic_batching/get_tokenizer.py | 34 ++++ .../inference/dynamic_batching/io_struct.py | 15 +- .../dynamic_batching/ray_dist_init.py | 164 ++++++++++++++++++ .../dynamic_batching/ray_init_config.py | 58 +++++++ colossalai/inference/manager.py | 92 +++++++--- requirements/requirements-test.txt | 2 + requirements/requirements.txt | 2 + .../test_dynamic_batching/config.yaml | 15 ++ .../test_dynamic_batching/test_ray_dist.py | 56 ++++++ 9 files changed, 407 insertions(+), 31 deletions(-) create mode 100644 colossalai/inference/dynamic_batching/get_tokenizer.py create mode 100644 colossalai/inference/dynamic_batching/ray_dist_init.py create mode 100644 colossalai/inference/dynamic_batching/ray_init_config.py create mode 100644 tests/test_infer/test_dynamic_batching/config.yaml create mode 100644 tests/test_infer/test_dynamic_batching/test_ray_dist.py diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py new file mode 100644 index 000000000000..af1f26848b3a --- /dev/null +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -0,0 +1,34 @@ +from transformers import AutoTokenizer + +_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" + + +def get_tokenizer( + tokenizer=None, + tokenizer_name: str = "", + trust_remote_code: bool = False, + use_fast: bool = True, +): + if tokenizer is not None: + tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." 
+ ) + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + except TypeError: + use_fast = False + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + return tokenizer diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 2b2739f0ae90..9faaad6f111e 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -4,7 +4,7 @@ class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): self.output_metadata_list = [] self.has_generate_finished = False self.aborted = False + self.prompts = prompts def to_rpc_obj(self): return { @@ -36,7 +37,11 @@ def stop_sequences_matched(self): if self.sample_params.stop_sequences is not None: for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) - if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): return True return False @@ -102,17 +107,21 @@ def mark_finished_req(self, eos_id): has_new_finish = True return has_new_finish - def filter_finished(self): + def filter_finished(self) -> List[Req]: """ Filter finished requests from the batch, the finished ones will be removed from 'reqs'. """ # TODO: the logic of return should be defined here. 
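An illustrative call to the get_tokenizer helper added above, assuming only what the patch shows; the tokenizer name is a placeholder for any local directory or hub id.

    from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer

    # placeholder name; llama-style names are redirected to a fast llama tokenizer by the helper
    tokenizer = get_tokenizer(tokenizer_name="YOUR_MODEL_DIR_OR_HUB_ID")
    prompt_ids = tokenizer.encode("Introduce some landmarks in Beijing")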
unfinished_req = [] + finished_req = [] for req in self.reqs: if not req.has_generate_finished: unfinished_req.append(req) + else: + finished_req.append(req) self.reqs = unfinished_req self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req def is_clear(self): return len(self.reqs) == 0 diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py new file mode 100644 index 000000000000..a40a00e2666c --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -0,0 +1,164 @@ +import asyncio +import logging +import os +from typing import List + +import ray +import ray.util.collective as collective +import torch +from transformers import AutoModelForCausalLM + +import colossalai +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +ray_serve_logger = logging.getLogger("ray.serve") + + +def log_cuda_info(scope_name: str): + ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") + ray_serve_logger.info( + f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" + ) + if torch.cuda.is_available(): + ray_serve_logger.info( + f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" + ) + else: + ray_serve_logger.info(f" {scope_name}: cuda is not available!") + + +@ray.remote(num_gpus=1) +class Worker: + def __init__( + self, + model_path: str, + tensor_parallel_size: int, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + router_config: RooterArgsClass, + ): + log_cuda_info("Worker.init") + self.tensor_parallel_size = tensor_parallel_size + self.model_path = model_path + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.router_config = router_config + + def setup(self, world_size, rank, port): + # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully + collective.init_collective_group(world_size, rank, "nccl", "default") + # initialize and set distributed environment + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") + log_cuda_info("Worker.setup") + + # Load model + self.tokenizer = get_tokenizer(tokenizer_name=self.model_path) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 + ) + + shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, []) + + return True + + def generate(self, request_id: str, prompt: 
str, sampling_params: SamplingParams) -> str: + ray_serve_logger.info(f"text: {prompt}") + + results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + + final_output = None + for request_output in results_generator: + final_output = request_output + + assert final_output is not None + ray_serve_logger.info(f"Generated text: {final_output}") + return final_output + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) + + def abort(self, request_id: str): + self.start_dynamic_batching.abort(request_id) + + def step(self): + self.start_dynamic_batching._step() + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) + + +class Driver: + def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): + log_cuda_info("Driver:init") + model_path = engine_config.model + tensor_parallel_size = engine_config.tensor_parallel_size + + self.num_workers = tensor_parallel_size + self.workers = [] + init_rets = [] + + # Just grab a free port on localhost + # NOTE workers in this communication group listen to the same port + available_port = free_port() + + for i in range(self.num_workers): + worker_name = "worker_idx_{}".format(i) + w = Worker.options(name=worker_name).remote( + model_path, + self.num_workers, + engine_config.max_batch_size, + engine_config.max_input_len, + engine_config.max_output_len, + router_config, + ) + self.workers.append(w) + init_rets.append(w.setup.remote(self.num_workers, i, available_port)) + _options = { + "group_name": "default_driver", + "world_size": self.num_workers, + "ranks": [i for i in range(self.num_workers)], + "backend": "nccl", + } + collective.create_collective_group(self.workers, **_options) + _ = ray.get(init_rets) + + # set batch wait delay in seconds and maximum number of sequences in a batch + def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) + text_res = results[0] # get any one of the copies + return text_res + + async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + all_outputs = [] + for worker in self.workers: + all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) + all_outputs = await asyncio.gather(*all_outputs) + text_res = all_outputs[0] # get any one of the copies + return text_res + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + + def abort(self, request_id: str): + ray.get([w.abort.remote(request_id) for w in self.workers]) + + def step(self): + ray.get([w._step.remote() for w in self.workers]) + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py new file mode 100644 index 000000000000..471f07330aec --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_init_config.py @@ -0,0 +1,58 @@ +import logging + +import yaml +from pydantic 
import BaseModel + +logger = logging.getLogger(__name__) + + +class EngineArgsClass(BaseModel): + """Config for Engine""" + + model: str + tensor_parallel_size: int = 2 + max_batch_size: int = 4 + max_input_len: int = 128 + max_output_len: int = 32 + + +class RooterArgsClass(BaseModel): + """Config for Rooter""" + + max_total_token_num: int = 42 + batch_max_tokens: int = 42 + eos_id: int = 0 + disable_log_stats: bool = False + log_stats_interval: int = 10 + model: str + + +class RayInitConfig(BaseModel): + """All-together configs without app router config""" + + engine_config_data: EngineArgsClass + router_config_data: RooterArgsClass + + @classmethod + def from_yaml_path(cls, path: str): + try: + with open(path, "r") as yaml_file: + try: + config = yaml.safe_load(yaml_file) + # serve deployment config + engine_config = config.get("engine_config", {}) + router_config = config.get("router_config", {}) + + return cls( + engine_config_data=engine_config, + router_config_data=router_config, + ) + except yaml.YAMLError as e: + logger.error(f"An Error occurred when parsing yaml: {e}") + raise + except FileNotFoundError: + logger.error(f"The file '{path}' does not exist!") + raise + except OSError as e: + logger.error(f"An Error occurred: {e}") + raise diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 72f77406789f..30717a915e3b 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,6 +1,7 @@ import time from typing import List +from .dynamic_batching.get_tokenizer import get_tokenizer from .dynamic_batching.infer_batch import InferBatch from .dynamic_batching.io_struct import Batch, Req from .dynamic_batching.req_queue import ReqQueue @@ -16,6 +17,7 @@ def __init__( max_total_token_num, batch_max_tokens, eos_id, + model, log_stats=True, log_stats_interval=10, running_batch: Batch = None, @@ -27,6 +29,7 @@ def __init__( batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir log_stats : whether to log stats log_stats_interval : log stats interval running_batch : running batch @@ -42,18 +45,32 @@ def __init__( self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - + self.model = model + self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 + self.tokenizer = get_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): """ Add new request to req queue, during initialization all requests are held in waiting list. """ - req = Req(request_id, prompt_ids, sampling_params) + req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) return + def add_input(self, request_id, sampling_params, prompts): + """ + Encode and Add new input to req queue. support one sequence input for now. 
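A small sketch of loading the pydantic config defined above; the yaml path is a placeholder and the file is assumed to contain the engine_config and router_config sections used by this patch.

    from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig

    config = RayInitConfig.from_yaml_path("config.yaml")  # placeholder path
    engine_args = config.engine_config_data               # EngineArgsClass
    router_args = config.router_config_data               # RooterArgsClass
    print(engine_args.model, engine_args.tensor_parallel_size, router_args.max_total_token_num)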
+ """ + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(prompt_ids, sampling_params, request_id, prompts) + return + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: @@ -71,8 +88,9 @@ def loop_for_fwd(self): The main loop for a dynamic batching process. """ counter_count = 0 + # self.running_batch is not None or self.req_queue.waiting_req_list while self.running_batch is not None or self.req_queue.waiting_req_list: - self._step() + yield from self._step() counter_count += 1 if self.running_batch is not None: if counter_count % self.mem_usage_interval == 0: @@ -97,14 +115,14 @@ def _step(self): if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - self._prefill_batch(self.running_batch) + yield from self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -112,14 +130,15 @@ def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - self._prefill_batch(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 + else: self.stats_tool.count_output_tokens(self.running_batch) - self._decode_batch(self.running_batch) + yield from self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 @@ -158,7 +177,8 @@ def _prefill_batch(self, batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs def _decode_batch(self, batch: Batch): @@ -169,7 +189,7 @@ def _decode_batch(self, batch: Batch): req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - self._handle_finish_req(batch, has_new_finished_req) + yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): batch_id = batch.batch_id @@ -201,11 +221,12 @@ def _remove_batch(self, batch): def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: - batch.filter_finished() + finished_reqs = batch.filter_finished() if batch.is_clear(): self._remove_batch(batch) else: self._filter_batch(batch) + yield from self._output_process(finished_reqs) def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): @@ -218,26 +239,41 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): req.output_metadata_list.append(new_gen_metadata) return + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. 
+ """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield req.prompts + output + def clean_up(self): # this logic should be implemented in the future. pass + def generate(self, prompts, sampling_params, request_id): + """ + Generate the output of a request. + """ + self.add_input(request_id, sampling_params, prompts) + return self.loop_for_fwd() + def start_dynamic_batching(args, tp_engine, waiting_req_list): - # try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - # except Exception: - # batch_manager.clean_up() - # raise - - batch_manager.loop_for_fwd() - return + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + return batch_manager diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 467f83610eb0..f54b13c7e43c 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,4 +18,6 @@ SentencePiece ninja flash_attn==2.0.5 datasets +pydantic +ray #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9aa5f2822e40..8a4b0f1a0ffd 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,3 +11,5 @@ ninja torch>=1.12 safetensors einops +pydantic +ray diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml new file mode 100644 index 000000000000..c31ae8c5fadb --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -0,0 +1,15 @@ +engine_config: + model: MODEL_PATH + tensor_parallel_size: 2 + max_batch_size: 4 + max_input_len: 128 + max_output_len: 32 +# config for app router deployment +# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig? 
+router_config: + max_total_token_num: 42 + batch_max_tokens: 42 + eos_id: 0 + disable_log_stats: False + log_stats_interval: 10 + model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py new file mode 100644 index 000000000000..4cf9881f41dc --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -0,0 +1,56 @@ +import asyncio +import os +import uuid + +from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +import colossalai +import pytest +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + +def run_ray_dist(path: str): + print(f"Using yaml file {path}") + if not os.path.exists(path): + raise FileNotFoundError(f"Invalid yaml file path {path}") + config = RayInitConfig.from_yaml_path(path) + router_config = config.router_config_data + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + driver = Driver(router_config=router_config, engine_config=engine_config) + prompt = "Introduce some landmarks in Beijing" + + request_id = str(uuid.uuid4().hex) + + sampling_params = SamplingParams() + + async def get_result(request_id, prompt, sampling_params): + return await driver.async_generate(request_id, prompt, sampling_params) + + for test_async in [True, False]: + if test_async: + print("test_async: ", test_async) + result = asyncio.run(get_result(request_id, prompt, sampling_params)) + print("result: ", result) + else: + print("test_async: ", test_async) + result = driver.generate(request_id, prompt, sampling_params) + print("result: ", result) + +def check_ray_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_ray_dist(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_ray_dist(): + spawn(check_ray_dist, 1) + +if __name__ == "__main__": + test_ray_dist() From c76fd687ab92f9d29400573e5fb6dea0f272d1de Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Tue, 17 Oct 2023 12:20:35 +0800 Subject: [PATCH 20/32] support dynamic batch for bloom model and is_running function --- .../dynamic_batching/ray_dist_init.py | 7 ++++ colossalai/inference/manager.py | 4 ++- .../tensor_parallel/modeling/bloom.py | 35 ++++++------------- .../test_dynamic_batching/config.yaml | 4 +-- .../test_dynamic_batching/test_ray_dist.py | 7 ++++ 5 files changed, 30 insertions(+), 27 deletions(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index a40a00e2666c..10bbfe250e4e 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -100,6 +100,9 @@ def step(self): def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) + + def is_running(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + return self.start_dynamic_batching.is_running() class Driver: @@ -162,3 +165,7 @@ def step(self): def add_req(self, prompt_ids: List[int], 
sampling_params: SamplingParams, request_id: str, prompt: str): ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) + + def is_running(self): + results = ray.get([w.is_running.remote() for w in self.workers]) + return any(results) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 30717a915e3b..bd33837dc451 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -257,7 +257,9 @@ def generate(self, prompts, sampling_params, request_id): """ self.add_input(request_id, sampling_params, prompts) return self.loop_for_fwd() - + + def is_running(self): + return self.running_batch is not None or self.req_queue.waiting_req_list def start_dynamic_batching(args, tp_engine, waiting_req_list): try: diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 27a26caabefa..c10f7e620852 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -133,17 +133,11 @@ def bloom_model_forward( assert hasattr(self, "infer_state") infer_state = self.infer_state - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - # if self.cache_manager.past_key_values_length > 0: - if infer_state.cache_manager.past_key_values_length > 0: - # update the past key values length in cache manager, - # NOTE use BatchInferState.past_key_values_length instead the one in cache manager - past_key_values_length = infer_state.cache_manager.past_key_values_length - seq_length_with_past = seq_length_with_past + past_key_values_length - # infer_state.cache_manager = self.cache_manager + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 if use_cache and seq_length != 1: # prefill stage @@ -160,21 +154,19 @@ def bloom_model_forward( infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) @@ -195,6 +187,7 @@ def bloom_model_forward( past_key_values_length=past_key_values_length, ) + 
infer_state.decode_layer_id = 0 for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -228,6 +221,7 @@ def custom_forward(*inputs): infer_state=infer_state, ) + infer_state.decode_layer_id += 1 hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1],) @@ -247,7 +241,6 @@ def custom_forward(*inputs): # and update these information in engine.generate after model foward called infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 - infer_state.decode_layer_id = 0 if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) @@ -453,9 +446,6 @@ def bloom_attention_forward( mem_manager = infer_state.cache_manager layer_id = infer_state.decode_layer_id - if layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_length # += 1 - if infer_state.is_context_stage: # context process max_input_len = q_length @@ -506,15 +496,12 @@ def bloom_attention_forward( b_loc, b_start_loc, b_seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, alibi, ) context_layer = output.view(batch_size, q_length, H * D_HEAD) - # update layer id - infer_state.decode_layer_id += 1 - # NOTE: always set present as none for now, instead of returning past key value to the next decoding, # we create the past key value pair from the cache manager present = None diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml index c31ae8c5fadb..6bd26a7f9fc7 100644 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -1,5 +1,5 @@ engine_config: - model: MODEL_PATH + model: /home/lccd/share/model_data/models--bigscience--bloom-560m/snapshots/4f42c91d806a19ae1a46af6c3fb5f4990d884cd6 tensor_parallel_size: 2 max_batch_size: 4 max_input_len: 128 @@ -12,4 +12,4 @@ router_config: eos_id: 0 disable_log_stats: False log_stats_interval: 10 - model: MODEL_PATH + model: /home/lccd/share/model_data/models--bigscience--bloom-560m/snapshots/4f42c91d806a19ae1a46af6c3fb5f4990d884cd6 diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 4cf9881f41dc..0eea9ef16345 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -35,11 +35,18 @@ async def get_result(request_id, prompt, sampling_params): if test_async: print("test_async: ", test_async) result = asyncio.run(get_result(request_id, prompt, sampling_params)) + assert result is not None print("result: ", result) else: print("test_async: ", test_async) result = driver.generate(request_id, prompt, sampling_params) + assert result is not None print("result: ", result) + + is_running = None + is_running = driver.is_running() + assert is_running is not None + print("is_running: ", is_running) def check_ray_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") From 4ea9fbec5c98c33f204b9aaa442403a4e58affe1 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 19 Oct 2023 10:04:09 +0800 Subject: [PATCH 21/32] [Inference]Test for new Async engine (#4935) * infer engine * 
infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> --- colossalai/inference/async_engine.py | 133 ++++++++++++++++ colossalai/inference/async_manager.py | 150 ++++++++++++++++++ .../dynamic_batching/get_tokenizer.py | 8 +- .../inference/dynamic_batching/infer_batch.py | 16 +- .../inference/dynamic_batching/io_struct.py | 55 ++++--- .../dynamic_batching/ray_dist_init.py | 48 ++---- .../inference/dynamic_batching/req_queue.py | 4 +- .../dynamic_batching/sampling_params.py | 17 +- .../inference/dynamic_batching/stats.py | 2 + colossalai/inference/manager.py | 11 +- .../test_async_engine.py | 60 +++++++ .../test_dynamic_batching_manager.py | 1 + .../test_dynamic_batching/test_ray_dist.py | 18 ++- 13 files changed, 441 insertions(+), 82 deletions(-) create mode 100644 colossalai/inference/async_engine.py create mode 100644 colossalai/inference/async_manager.py create mode 100644 tests/test_infer/test_dynamic_batching/test_async_engine.py diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py new file mode 100644 index 000000000000..a58dde01d250 --- /dev/null +++ b/colossalai/inference/async_engine.py @@ -0,0 +1,133 @@ +import asyncio + +from colossalai.inference.dynamic_batching.ray_dist_init import Driver + +from .dynamic_batching.io_struct import RequestOutput +from .dynamic_batching.sampling_params import SamplingParams + + +class RequestTracker: + """ + A class for trace down all the requests, abstraction for async + """ + + def __init__(self) -> None: + self._requests: asyncio.Queue[str] = asyncio.Queue() + self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._requests + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def add_request(self, request_id: str): + """Add a request to be sent to the engine on the next background + loop iteration.""" + self._requests.put_nowait(request_id) + self.new_requests_event.set() # NOTE: we may find a better way to clear this event + + def add_stop(self): + """ + Add a StopIteration flag to stop async generator. 
+ """ + self._finished_requests.put_nowait(StopIteration) + self.new_requests_event.clear() + + def process_request_output(self, request_output: RequestOutput) -> None: + """Process a request output from the engine.""" + self._finished_requests.put_nowait(request_output) + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._finished_requests.get() + # print("result of ", result) + if result is StopIteration: + raise StopAsyncIteration + return result + + +class Async_Engine: + + """ + Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager + Background loop: inference reqs in waiting list (Listen) + Request Tracker: manage incoming requests and restore finished ones + Generate: exposed func for add new input and return finished ones + """ + + def __init__( + self, + router_config, + engine_config, + start_engine_loop: bool = True, + ) -> None: + self.driver = Driver(router_config=router_config, engine_config=engine_config) + self.background_loop = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + def _step(self): + """ + Logic for handling requests + """ + request_outputs = self.driver.step() + if request_outputs is not None: + for request_output in request_outputs: + self._request_tracker.process_request_output(request_output) + self._request_tracker.add_stop() + + def abort(self, request_id: str): + self.driver.abort(request_id) + + def _has_requests_in_progress(self): + return self.driver.is_running() + + async def run_loop_fwd(self): + has_requests_in_progress = self._has_requests_in_progress() + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + self._step() + await asyncio.sleep(0) + + @property + def is_running(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.is_running: + raise RuntimeError("Background loop is already running.") + + self._request_tracker.init_event() + + self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) + self.background_loop = asyncio.shield(self.background_loop_unshielded) + + async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.driver.add_input(request_id, prompt, sampling_params) + self._request_tracker.add_request(request_id) + + async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + """ + The only exposed func, adding new request and return a async generator that yields the existing results. + """ + try: + if not self.is_running: + self.start_background_loop() + + await self.add_request(request_id, prompt, sampling_params) + + async for request_output in self._request_tracker: + yield request_output + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the request. 
+ self.abort_request(request_id) + raise e diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py new file mode 100644 index 000000000000..78d11b1caa44 --- /dev/null +++ b/colossalai/inference/async_manager.py @@ -0,0 +1,150 @@ +from typing import List + +from .dynamic_batching.io_struct import Batch, Req, RequestOutput +from .manager import DynamicBatchManager +from .tensor_parallel import TPInferEngine + + +class Async_DynamicBatchManager(DynamicBatchManager): + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + eos_id, + model, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + super().__init__( + tp_engine, + max_total_token_num, + batch_max_tokens, + eos_id, + model, + log_stats, + log_stats_interval, + running_batch, + waiting_req_list, + ) + + def _step(self): + """ + Logic for handling requests + """ + has_new_finished = False + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + has_new_finished, outputs = self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + + else: + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + has_new_finished, outputs = self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + if has_new_finished: + return outputs + return None + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. 
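A hedged sketch of consuming the Async_Engine.generate async generator introduced in this patch; the config objects are placeholders, and constructing the engine assumes a Ray environment with GPUs is available.

    import asyncio
    import uuid

    from colossalai.inference.async_engine import Async_Engine
    from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

    async def collect(engine: Async_Engine, prompt: str):
        request_id = uuid.uuid4().hex
        last = None
        # generate() starts the background loop on first use and yields RequestOutput items
        async for request_output in engine.generate(request_id, prompt, SamplingParams()):
            last = request_output
        return last

    # engine = Async_Engine(router_config=..., engine_config=...)  # placeholder configs
    # asyncio.run(collect(engine, "Introduce some landmarks in Beijing"))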
+ """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + return self._output_process(finished_reqs) + return None + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + outputs = [] + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) + return outputs + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = Async_DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + return batch_manager diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py index af1f26848b3a..94aa3f24393f 100644 --- a/colossalai/inference/dynamic_batching/get_tokenizer.py +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -1,6 +1,12 @@ +""" +Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. + +license: MIT, see LICENSE for more details. 
+""" + from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" def get_tokenizer( diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index 826272db3e11..112784c15f84 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -1,15 +1,16 @@ +# Adapted from https://github.com/ModelTC/lightllm + import collections from dataclasses import dataclass -from typing import Dict, List , Tuple +from typing import Dict, List, Tuple import numpy as np import torch from colossalai.inference.tensor_parallel import MemoryManager -# make batch infer state an attr of InferBatch - +# make batch infer state an attr of InferBatch class InferSamplingParams: def __init__( self, @@ -65,7 +66,7 @@ def init_batch( cache_manager: MemoryManager, vocab_size: int, max_total_len: int, - ) -> 'InferBatch': + ) -> "InferBatch": input_lengths = [] all_input_ids = [] requests_idx_mapping = {} @@ -76,7 +77,7 @@ def init_batch( nopad_total_token_num = 0 nopad_max_len_in_batch = 0 nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") - # to avoid memory leak , we pre-allocate 12 more space for each batch. + # to avoid memory leak , we pre-allocate 12 more space for each batch. nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") for i, r in enumerate(requests): # request id -> idx in list mapping @@ -142,10 +143,9 @@ def free_self(self) -> None: ) remove_index = torch.cat(remove_index, dim=-1) self.cache_manager.free(remove_index) - @torch.no_grad() - def filter(self, request_ids: List[int]) -> 'InferBatch': + def filter(self, request_ids: List[int]) -> "InferBatch": """ Filter finished batch and return a new InferBatch with left ones. 
""" @@ -226,7 +226,7 @@ def filter(self, request_ids: List[int]) -> 'InferBatch': @classmethod @torch.no_grad() - def merge(cls, batch1, batch2) -> 'InferBatch': + def merge(cls, batch1, batch2) -> "InferBatch": """ Return megerd new InferBatch """ diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 9faaad6f111e..a75eb8007a02 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -1,10 +1,12 @@ +# Adapted from https://github.com/ModelTC/lightllm + from typing import Dict, List, Tuple from .sampling_params import SamplingParams class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -49,26 +51,6 @@ def __repr__(self): return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " -class ReqDetokenizationState: - def __init__( - self, - request_id: str, - prompt_ids: List[int], - max_output_len: int, - ignore_eos: bool, - ) -> None: - self.request_id = request_id - self.prompt_ids = prompt_ids - self.output_ids = [] - self.output_tokens = [] - self.output_str = "" - self.sub_texts = [] - self.current_sub_text = [] - self.max_output_len = max_output_len - self.ignore_eos = ignore_eos - self.gen_metadata = {} - - class Batch: def __init__(self, batch_id, reqs: List[Req]): self.batch_id = batch_id @@ -156,3 +138,34 @@ def __init__(self): class AbortReq: def __init__(self, req_id): self.req_id = req_id + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + outputs: The output sequences of the request. 
+ """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + outputs, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs + + def __repr__(self) -> str: + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"outputs={self.outputs}, " + ) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 70cc21436456..7639633eaa79 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -1,4 +1,3 @@ -import asyncio import logging import os from typing import List @@ -9,10 +8,11 @@ from transformers import AutoModelForCausalLM import colossalai +from colossalai.inference.async_manager import start_dynamic_batching from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.io_struct import RequestOutput from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import free_port @@ -76,31 +76,25 @@ def setup(self, world_size, rank, port): return True - def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str: - ray_serve_logger.info(f"text: {prompt}") + # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]: + # ray_serve_logger.info(f"text: {prompt}") - results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) - final_output = None - for request_output in results_generator: - final_output = request_output - - assert final_output is not None - ray_serve_logger.info(f"Generated text: {final_output}") - return final_output + # return final_outputs def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) + self.start_dynamic_batching.add_input(request_id, prompt, sampling_params) def abort(self, request_id: str): self.start_dynamic_batching.abort(request_id) - def step(self): - self.start_dynamic_batching._step() + def step(self) -> List[RequestOutput]: + return self.start_dynamic_batching._step() def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - + def is_running(self): return self.start_dynamic_batching.is_running() @@ -140,32 +134,20 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas collective.create_collective_group(self.workers, **_options) _ = ray.get(init_rets) - # set batch wait delay in seconds and maximum number of sequences in a batch - def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) - text_res = results[0] # get any one of the copies - return text_res - - 
async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - all_outputs = [] - for worker in self.workers: - all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) - all_outputs = await asyncio.gather(*all_outputs) - text_res = all_outputs[0] # get any one of the copies - return text_res - def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) def abort(self, request_id: str): ray.get([w.abort.remote(request_id) for w in self.workers]) def step(self): - ray.get([w._step.remote() for w in self.workers]) + results = ray.get([w.step.remote() for w in self.workers]) + outputs = results[0] # get any one of the copies + return outputs def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) - + def is_running(self): results = ray.get([w.is_running.remote() for w in self.workers]) return any(results) diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py index d9e9b6269cc4..0de43bd1a21f 100644 --- a/colossalai/inference/dynamic_batching/req_queue.py +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import uuid from typing import List @@ -41,7 +43,7 @@ def _can_add_new_req(self, req): need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() # NOTE: change here < to <= return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size - + def generate_new_batch(self, current_batch: Batch = None): if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: return None diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py index 9a0ace4111dd..2028da907259 100644 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + """Sampling parameters for text generation.""" from typing import List, Optional, Union @@ -5,7 +7,6 @@ class SamplingParams: - def __init__( self, do_sample: bool = False, @@ -13,10 +14,10 @@ def __init__( frequency_penalty: float = 0.0, temperature: float = 1.0, top_p: float = 1.0, - top_k: int = -1, # -1 is for all + top_k: int = -1, # -1 is for all ignore_eos: bool = False, max_new_tokens: int = 16, - stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation + stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation ) -> None: self.do_sample = do_sample self.presence_penalty = presence_penalty @@ -31,11 +32,13 @@ def __init__( self.temperature = 1.0 self.top_p = 1.0 self.top_k = 1 - if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search + if ( + self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS + ): # temperature is too slow, change to greedy search self.temperature = 1.0 self.top_k = 1 return - + def verify(self): if self.presence_penalty < 0.0: raise ValueError(f"presence_penalty must >= 0.0, got 
{self.presence_penalty}") @@ -60,13 +63,13 @@ def stop_sentences_to_token_ids(self, tokenizer): new_stop_sequences = [] for stop_str in self.stop_sequences: stop_str_ids = tokenizer.encode(stop_str) - if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id stop_str_ids = stop_str_ids[1:] if len(stop_str_ids) > 0: new_stop_sequences.append(stop_str_ids) self.stop_sequences = new_stop_sequences return - + def to_dict(self): ret = {} ret["do_sample"] = self.do_sample diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py index 6d34183f47c4..524072861a3f 100644 --- a/colossalai/inference/dynamic_batching/stats.py +++ b/colossalai/inference/dynamic_batching/stats.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import time diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index bd33837dc451..42ff8bf1e9ef 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import time from typing import List @@ -51,7 +53,7 @@ def __init__( self.mem_usage_interval = log_stats_interval * 2 self.tokenizer = get_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str = ""): """ Add new request to req queue, during initialization all requests are held in waiting list. """ @@ -59,7 +61,7 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, prompts): + def add_input(self, request_id, prompts, sampling_params): """ Encode and Add new input to req queue. support one sequence input for now. 
""" @@ -257,9 +259,10 @@ def generate(self, prompts, sampling_params, request_id): """ self.add_input(request_id, sampling_params, prompts) return self.loop_for_fwd() - + def is_running(self): - return self.running_batch is not None or self.req_queue.waiting_req_list + return self.running_batch is not None or self.req_queue.waiting_req_list + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py new file mode 100644 index 000000000000..148d325a1d9a --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -0,0 +1,60 @@ +import asyncio +import os +import uuid + +import pytest + +import colossalai +from colossalai.inference.async_engine import Async_Engine +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + + +def run_async_engine(path: str): + if not os.path.exists(path): + return + + config = RayInitConfig.from_yaml_path(path) + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + + prompt = "Introduce some landmarks in Beijing" + sampling_params = SamplingParams() + asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) + + +async def get_result(engine, prompt, sampling_params): + request_id = str(uuid.uuid4().hex) + results = engine.generate(request_id, prompt, sampling_params) + async for result in results: + assert result is not None + + +async def asy_for_loop_test(config, prompt, sampling_params): + router_config = config.router_config_data + engine_config = config.engine_config_data + engine = Async_Engine(router_config=router_config, engine_config=engine_config) + for i in range(10): + print("in for loop", i) + await get_result(engine, prompt, sampling_params) + + +def check_async_engine(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_async_engine(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_async_engine(): + spawn(check_async_engine, 1) + + +if __name__ == "__main__": + test_async_engine() diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py index 124f1f478b00..588922b5a58f 100644 --- a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -45,6 +45,7 @@ def run(): log_stats=False, log_stats_interval=10, waiting_req_list=waiting_list, + model="llama", ) before_add = len(dynamic_batch_manager.req_queue) diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 0eea9ef16345..5c84b39d8f8e 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -2,19 +2,21 @@ import os import uuid +import pytest + +import colossalai from colossalai.inference.dynamic_batching.ray_dist_init import Driver from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from 
colossalai.inference.dynamic_batching.sampling_params import SamplingParams -import colossalai -import pytest from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn PATH = "config.yaml" + def run_ray_dist(path: str): print(f"Using yaml file {path}") if not os.path.exists(path): - raise FileNotFoundError(f"Invalid yaml file path {path}") + return config = RayInitConfig.from_yaml_path(path) router_config = config.router_config_data engine_config = config.engine_config_data @@ -25,8 +27,8 @@ def run_ray_dist(path: str): prompt = "Introduce some landmarks in Beijing" request_id = str(uuid.uuid4().hex) - sampling_params = SamplingParams() + print("sampling_params: ", sampling_params) async def get_result(request_id, prompt, sampling_params): return await driver.async_generate(request_id, prompt, sampling_params) @@ -35,19 +37,20 @@ async def get_result(request_id, prompt, sampling_params): if test_async: print("test_async: ", test_async) result = asyncio.run(get_result(request_id, prompt, sampling_params)) - assert result is not None + assert result is not None print("result: ", result) else: print("test_async: ", test_async) result = driver.generate(request_id, prompt, sampling_params) assert result is not None print("result: ", result) - + is_running = None is_running = driver.is_running() - assert is_running is not None + assert is_running is not None print("is_running: ", is_running) + def check_ray_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_ray_dist(PATH) @@ -59,5 +62,6 @@ def check_ray_dist(rank, world_size, port): def test_ray_dist(): spawn(check_ray_dist, 1) + if __name__ == "__main__": test_ray_dist() From 3f6af12428396d343a7472b1914470b2b2a41688 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:51:05 +0800 Subject: [PATCH 22/32] add assertion for config (#4947) --- colossalai/inference/async_engine.py | 3 ++- colossalai/inference/manager.py | 7 ++++++- colossalai/inference/tensor_parallel/engine.py | 2 -- tests/test_infer/test_dynamic_batching/config.yaml | 6 +++--- .../test_infer/test_dynamic_batching/test_async_engine.py | 3 ++- tests/test_infer/test_dynamic_batching/test_ray_dist.py | 1 - 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index a58dde01d250..4515aeab3469 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -79,11 +79,12 @@ def _step(self): """ request_outputs = self.driver.step() if request_outputs is not None: + print("request_outputs: ", request_outputs) for request_output in request_outputs: self._request_tracker.process_request_output(request_output) self._request_tracker.add_stop() - def abort(self, request_id: str): + def abort_request(self, request_id: str): self.driver.abort(request_id) def _has_requests_in_progress(self): diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 42ff8bf1e9ef..18226d78c20c 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -42,7 +42,12 @@ def __init__( running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) # all the inputs should be put into req_queue: waiting req list - + assert max_total_token_num >= 
self.engine.max_batch_size * ( + self.engine.max_input_len + self.engine.max_output_len + ), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)" + assert ( + batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len + ), "batch_max_tokens should be greater than (max_input_len+max_output_len)" self.running_batch: Batch = running_batch self.eos_id = eos_id self.has_wait_tokens = 0 diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index f7fb7a825694..a98b96565c50 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -61,7 +61,6 @@ def __init__( self.max_input_len = max_input_len self.max_output_len = max_output_len self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) - # Constraints relatable with specs of devices and model # This may change into an optional arg in the future assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" @@ -380,7 +379,6 @@ def forward(self, batch_id, is_prefill): Forward is used in Dynamic Batching Manager """ batch = self.cache.pop(batch_id) - if is_prefill: input_ = torch.tensor(batch.all_input_ids).cuda() else: diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml index c31ae8c5fadb..59ec39779335 100644 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -5,10 +5,10 @@ engine_config: max_input_len: 128 max_output_len: 32 # config for app router deployment -# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig? +# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig. router_config: - max_total_token_num: 42 - batch_max_tokens: 42 + max_total_token_num: 640 + batch_max_tokens: 640 eos_id: 0 disable_log_stats: False log_stats_interval: 10 diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index 148d325a1d9a..6287699d7b3c 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -23,7 +23,7 @@ def run_async_engine(path: str): if model is None or not os.path.exists(model): return - prompt = "Introduce some landmarks in Beijing" + prompt = "Introduce some landmarks in London.\nThe Tower of London is a historic castle on the north bank of the River Thames in central London. 
It was founded towards the end of 10" sampling_params = SamplingParams() asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) @@ -32,6 +32,7 @@ async def get_result(engine, prompt, sampling_params): request_id = str(uuid.uuid4().hex) results = engine.generate(request_id, prompt, sampling_params) async for result in results: + # print(result) assert result is not None diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 5c84b39d8f8e..a840407d5867 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -14,7 +14,6 @@ def run_ray_dist(path: str): - print(f"Using yaml file {path}") if not os.path.exists(path): return config = RayInitConfig.from_yaml_path(path) From 4867561c6f71588883eaafef2fc194d36655f385 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 19 Oct 2023 22:26:08 +0800 Subject: [PATCH 23/32] [Inference] Finish dynamic batching offline test (#4948) * test * fix test --- .../dynamic_batching/ray_dist_init.py | 2 +- colossalai/inference/manager.py | 8 +++---- .../inference/tensor_parallel/engine.py | 2 +- .../tensor_parallel/modeling/llama.py | 3 +-- ...rd.py => test_offline_dynamic_batching.py} | 22 +++++++++++++++---- 5 files changed, 25 insertions(+), 12 deletions(-) rename tests/test_infer/test_dynamic_batching/{test_forward.py => test_offline_dynamic_batching.py} (80%) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 7639633eaa79..e3a261f96a86 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -145,7 +145,7 @@ def step(self): outputs = results[0] # get any one of the copies return outputs - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str): ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) def is_running(self): diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 18226d78c20c..55a75d5e3925 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -58,7 +58,7 @@ def __init__( self.mem_usage_interval = log_stats_interval * 2 self.tokenizer = get_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str = ""): + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): """ Add new request to req queue, during initialization all requests are held in waiting list. """ @@ -75,7 +75,7 @@ def add_input(self, request_id, prompts, sampling_params): if prompt_len > self.engine.max_input_len: raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(prompt_ids, sampling_params, request_id, prompts) + self.add_req(request_id, prompt_ids, sampling_params, prompts) return def abort(self, request_id): @@ -258,11 +258,11 @@ def clean_up(self): # this logic should be implemented in the future. 
pass - def generate(self, prompts, sampling_params, request_id): + def generate(self, request_id, prompts, sampling_params): """ Generate the output of a request. """ - self.add_input(request_id, sampling_params, prompts) + self.add_input(request_id, prompts, sampling_params) return self.loop_for_fwd() def is_running(self): diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index a98b96565c50..e75004d506a3 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -400,8 +400,8 @@ def forward(self, batch_id, is_prefill): model = self.model.model elif isinstance(model, BloomForCausalLM): model = self.model.transformer + setattr(model, "infer_state", infer_state) - output = self.model.forward(input_ids=input_) logits = output.logits # bsz, seq_len, vocab_size diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 958868a0974e..7e6978ad815b 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -76,14 +76,13 @@ def llama_model_forward( batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # NOT READY FOR PRIME TIME # dummy but work, revise it if infer_state.is_context_stage: past_key_values_length = 0 else: past_key_values_length = infer_state.max_len_in_batch - 1 - + # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage if use_cache and seq_length != 1: diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py similarity index 80% rename from tests/test_infer/test_dynamic_batching/test_forward.py rename to tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py index 63df491e5b52..9925a80b6e77 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + import pytest import torch from packaging import version @@ -5,7 +7,6 @@ from transformers.models.llama.configuration_llama import LlamaConfig import colossalai -from dataclasses import dataclass from colossalai.inference.dynamic_batching.io_struct import Req from colossalai.inference.dynamic_batching.sampling_params import SamplingParams from colossalai.inference.manager import start_dynamic_batching @@ -19,17 +20,26 @@ MAX_OUTPUT_LEN = 16 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + @dataclass class args: max_total_token_num: int batch_max_tokens: int + model: str eos_id: int disable_log_stats: bool log_stats_interval: int def run(): - arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) + arg = args( + max_total_token_num=42, + model="llama", + batch_max_tokens=42, + eos_id=0, + disable_log_stats=False, + log_stats_interval=10, + ) sampling_params = SamplingParams() req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) @@ -43,14 +53,18 @@ def run(): waiting_list.append(req3) waiting_list.append(req4) - llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, 
vocab_size=30000, hidden_size=1024) model = LlamaForCausalLM(llama_config) model = model.half() shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + + ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params) + for result in ans_gen: + assert result is not None def check_dynamic_forward(rank, world_size, port): From d5d2c94e41d9d919f9eb54b3d5219c0252b097ea Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 20 Oct 2023 15:37:51 +0800 Subject: [PATCH 24/32] fix quant --- .../quant/smoothquant/models/base_model.py | 1 - .../quant/smoothquant/models/llama.py | 27 ++++++------------- .../tensor_parallel/modeling/llama.py | 5 ++-- 3 files changed, 10 insertions(+), 23 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 180e6c6e8fa6..ba46e280346e 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -87,7 +87,6 @@ def init_batch_state(self, max_output_len=256, **kwargs): batch_infer_state.start_loc = seq_start_indexes.to("cuda") batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 - batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) batch_infer_state.cache_manager.free_all() diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 9c77feeb346e..bd2bce119b25 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -149,12 +149,6 @@ def forward( self.k_rotary_output_scale.item(), ) - # NOTE might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) @@ -229,7 +223,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) @@ -590,17 +584,13 @@ def llama_model_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - infer_state = self.infer_state + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # 
past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + seq_length_with_past = seq_length + past_key_values_length # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -621,9 +611,7 @@ def llama_model_forward( infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem @@ -711,6 +699,7 @@ def llama_model_forward( infer_state.is_context_stage = False infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 next_cache = next_decoder_cache if use_cache else None if not return_dict: diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index d2ac4160ef64..a17b901dc7fd 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -88,8 +88,7 @@ def llama_model_forward( batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # NOT READY FOR PRIME TIME - # dummy but work, revise it + if infer_state.is_context_stage: past_key_values_length = 0 else: @@ -122,7 +121,7 @@ def llama_model_forward( infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, se - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device From ed86584187ea6a68d07ac1f63336fe417ab71926 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 20 Oct 2023 16:09:10 +0800 Subject: [PATCH 25/32] add default --- colossalai/inference/async_engine.py | 1 - colossalai/inference/async_manager.py | 2 +- colossalai/inference/manager.py | 4 +++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py index 4515aeab3469..d0890ba3e9fc 100644 --- a/colossalai/inference/async_engine.py +++ b/colossalai/inference/async_engine.py @@ -79,7 +79,6 @@ def _step(self): """ request_outputs = self.driver.step() if request_outputs is not None: - print("request_outputs: ", request_outputs) for request_output in request_outputs: self._request_tracker.process_request_output(request_output) self._request_tracker.add_stop() diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index 9abe465eb319..aedb5721d62c 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -11,8 +11,8 @@ def __init__( tp_engine: TPInferEngine, max_total_token_num, batch_max_tokens, - eos_id, model, + 
eos_id=None, log_stats=True, log_stats_interval=10, running_batch: Batch = None, diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 0d165066e09a..ff3b061c45fa 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -18,8 +18,8 @@ def __init__( tp_engine: TPInferEngine, max_total_token_num, batch_max_tokens, - eos_id, model, + eos_id=None, log_stats=True, log_stats_interval=10, running_batch: Batch = None, @@ -57,6 +57,8 @@ def __init__( self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 self.tokenizer = get_tokenizer(tokenizer_name=self.model) + if self.eos_id is None: + self.eos_id = self.tokenizer.eos_token_id def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): """ From 4bffb8b5674d835a647516ed93168e48e088bf0c Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 20 Oct 2023 16:39:57 +0800 Subject: [PATCH 26/32] fix --- colossalai/inference/async_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index aedb5721d62c..98de28970831 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -9,9 +9,9 @@ class Async_DynamicBatchManager(DynamicBatchManager): def __init__( self, tp_engine: TPInferEngine, - max_total_token_num, - batch_max_tokens, - model, + max_total_token_num: int, + batch_max_tokens: int, + model: str, eos_id=None, log_stats=True, log_stats_interval=10, @@ -34,8 +34,8 @@ def __init__( tp_engine, max_total_token_num, batch_max_tokens, - eos_id, model, + eos_id, log_stats, log_stats_interval, running_batch, From 77adc2e30e0d063112720b676cc425f947e3fd42 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 20 Oct 2023 18:24:52 +0800 Subject: [PATCH 27/32] fix some bugs --- colossalai/inference/async_manager.py | 4 ++-- colossalai/inference/dynamic_batching/io_struct.py | 13 ++++--------- .../inference/dynamic_batching/ray_dist_init.py | 2 +- .../inference/dynamic_batching/sampling_params.py | 4 +--- colossalai/inference/manager.py | 4 ++-- tests/test_infer/test_dynamic_batching/config.yaml | 13 ++++++------- .../test_dynamic_batching/test_async_engine.py | 2 +- 7 files changed, 17 insertions(+), 25 deletions(-) diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index 98de28970831..fa49444a3920 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -93,7 +93,7 @@ def _prefill_batch(self, batch): ans = self.engine._prefill_batch(batch.batch_id) req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) outputs = self._handle_finish_req(batch, has_new_finished_req) return has_new_finished_req, outputs # delete finished reqs @@ -105,7 +105,7 @@ def _decode_batch(self, batch: Batch): ans = self.engine._decode_batch(batch.batch_id) req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) outputs = self._handle_finish_req(batch, has_new_finished_req) return has_new_finished_req, outputs diff --git a/colossalai/inference/dynamic_batching/io_struct.py 
b/colossalai/inference/dynamic_batching/io_struct.py index a75eb8007a02..fc5ecfe5796b 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -26,14 +26,6 @@ def to_rpc_obj(self): "sampling_param": self.sample_params.to_dict(), } - def to_req_detokenization_state(self): - out = ReqDetokenizationState( - self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos - ) - if self.output_metadata_list: - out.gen_metadata.update(self.output_metadata_list[-1]) - return out - def stop_sequences_matched(self): # should we add stpp sequences to the sample params? if self.sample_params.stop_sequences is not None: @@ -75,12 +67,15 @@ def calcu_used_tokens(self): tokens += req.input_len + len(req.output_ids) return tokens - def mark_finished_req(self, eos_id): + def mark_finished_req(self, eos_id, engine_max_output_len): has_new_finish = False for req in self.reqs: if req.stop_sequences_matched(): req.has_generate_finished = True has_new_finish = True + if len(req.output_ids) >= engine_max_output_len: + req.has_generate_finished = True + has_new_finish = True if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: req.has_generate_finished = True has_new_finish = True diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index e3a261f96a86..169c2d1248e9 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -67,7 +67,7 @@ def setup(self, world_size, rank, port): self.model = AutoModelForCausalLM.from_pretrained( self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 ) - + print(self.max_batch_size, self.max_input_len, self.max_output_len) shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) self.infer_engine = TPInferEngine( self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py index 2028da907259..184f165dc35f 100644 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -16,7 +16,7 @@ def __init__( top_p: float = 1.0, top_k: int = -1, # -1 is for all ignore_eos: bool = False, - max_new_tokens: int = 16, + max_new_tokens: int = 1024, stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation ) -> None: self.do_sample = do_sample @@ -80,6 +80,4 @@ def to_dict(self): ret["top_k"] = self.top_k # if self.ignore_eos is not None: # ret["ignore_eos"] = self.ignore_eos - # if self.max_tokens is not None: - # ret["max_tokens"] = self.max_tokens return ret diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index ff3b061c45fa..d59a05ebc2a7 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -185,7 +185,7 @@ def _prefill_batch(self, batch): ans = self.engine._prefill_batch(batch.batch_id) req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) yield from self._handle_finish_req(batch, has_new_finished_req) # delete finished reqs @@ -197,7 +197,7 @@ def _decode_batch(self, batch: 
Batch): ans = self.engine._decode_batch(batch.batch_id) req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) yield from self._handle_finish_req(batch, has_new_finished_req) def _filter_batch(self, batch: Batch): diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml index 59ec39779335..0ac778a3c7b3 100644 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -1,15 +1,14 @@ engine_config: model: MODEL_PATH - tensor_parallel_size: 2 - max_batch_size: 4 - max_input_len: 128 - max_output_len: 32 + tensor_parallel_size: 1 + max_batch_size: 2 + max_input_len: 1024 + max_output_len: 512 # config for app router deployment # Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig. router_config: - max_total_token_num: 640 - batch_max_tokens: 640 - eos_id: 0 + max_total_token_num: 4096 + batch_max_tokens: 4096 disable_log_stats: False log_stats_interval: 10 model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py index 6287699d7b3c..512aa7430983 100644 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -23,7 +23,7 @@ def run_async_engine(path: str): if model is None or not os.path.exists(model): return - prompt = "Introduce some landmarks in London.\nThe Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10" + prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. 
It was founded towards the end of 10" sampling_params = SamplingParams() asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) From dcb51b4a0b1b07fd366ce20c1e88ec95bfcaa4b2 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 20 Oct 2023 18:25:06 +0800 Subject: [PATCH 28/32] fix some bugs --- colossalai/inference/dynamic_batching/ray_dist_init.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 169c2d1248e9..70ef489d3b70 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -67,7 +67,6 @@ def setup(self, world_size, rank, port): self.model = AutoModelForCausalLM.from_pretrained( self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 ) - print(self.max_batch_size, self.max_input_len, self.max_output_len) shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) self.infer_engine = TPInferEngine( self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len From afae53b57a06f40ccdba1cb4bdb2250b4b601fd9 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 23 Oct 2023 10:32:27 +0800 Subject: [PATCH 29/32] fix --- colossalai/inference/dynamic_batching/sampling_params.py | 2 +- .../test_dynamic_batching/test_dynamic_batching_manager.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py index 184f165dc35f..033463bc6950 100644 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -16,7 +16,7 @@ def __init__( top_p: float = 1.0, top_k: int = -1, # -1 is for all ignore_eos: bool = False, - max_new_tokens: int = 1024, + max_new_tokens: int = 16, stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation ) -> None: self.do_sample = do_sample diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py index 588922b5a58f..1d10ad55deb8 100644 --- a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -50,12 +50,14 @@ def run(): before_add = len(dynamic_batch_manager.req_queue) # test add req function - dynamic_batch_manager.add_req(req4.prompt_ids, req4.sample_params, req4.request_id) + dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params) assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1 # test abort function dynamic_batch_manager.abort(req4.request_id) assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True + + print(dynamic_batch_manager.req_queue.waiting_req_list) # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested batch = dynamic_batch_manager.req_queue.generate_new_batch() From c4772667d71110f20e066c6d47db40528d5d908a Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 23 Oct 2023 16:05:59 +0800 Subject: [PATCH 30/32] fix bug --- colossalai/inference/dynamic_batching/sampling_params.py | 2 +- colossalai/inference/manager.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git 
a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py index 033463bc6950..a37a83390021 100644 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -16,7 +16,7 @@ def __init__( top_p: float = 1.0, top_k: int = -1, # -1 is for all ignore_eos: bool = False, - max_new_tokens: int = 16, + max_new_tokens: int = 256, stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation ) -> None: self.do_sample = do_sample diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index d59a05ebc2a7..371e2b681c56 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -64,6 +64,11 @@ def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: Sampl """ Add new request to req queue, during initialization all requests are held in waiting list. """ + sampling_params.max_new_tokens = ( + self.engine.max_output_len + if sampling_params.max_new_tokens > self.engine.max_output_len + else sampling_params.max_new_tokens + ) req = Req(request_id, prompt_ids, sampling_params, prompts) self.req_queue.append(req) return From f99eba23be75e3d587cc0f73b117fe3057be3705 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 23 Oct 2023 18:28:20 +0800 Subject: [PATCH 31/32] fix bugs --- colossalai/inference/async_manager.py | 2 ++ colossalai/inference/manager.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py index fa49444a3920..60440a792f1c 100644 --- a/colossalai/inference/async_manager.py +++ b/colossalai/inference/async_manager.py @@ -12,6 +12,7 @@ def __init__( max_total_token_num: int, batch_max_tokens: int, model: str, + tokenizer=None, eos_id=None, log_stats=True, log_stats_interval=10, @@ -35,6 +36,7 @@ def __init__( max_total_token_num, batch_max_tokens, model, + tokenizer, eos_id, log_stats, log_stats_interval, diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 371e2b681c56..9672a50141a0 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -19,6 +19,7 @@ def __init__( max_total_token_num, batch_max_tokens, model, + tokenizer=None, eos_id=None, log_stats=True, log_stats_interval=10, @@ -56,8 +57,8 @@ def __init__( self.stats_tool = Stats(log_stats, log_stats_interval) self.mem_usage_interval = log_stats_interval * 2 - self.tokenizer = get_tokenizer(tokenizer_name=self.model) - if self.eos_id is None: + self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer + if self.eos_id == None: self.eos_id = self.tokenizer.eos_token_id def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): From 4c3ea404683bc9f2335a9a245b83c92c157bf61d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 27 Oct 2023 10:14:10 +0800 Subject: [PATCH 32/32] reset param --- .../test_dynamic_batching_manager.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py index 1d10ad55deb8..78df0d304096 100644 --- a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -12,8 +12,8 @@ TP_SIZE = 1 BATCH_SIZE = 2 
-MAX_INPUT_LEN = 5 -MAX_OUTPUT_LEN = 16 +MAX_INPUT_LEN = 48 +MAX_OUTPUT_LEN = 256 def run(): @@ -39,8 +39,8 @@ def run(): dynamic_batch_manager = DynamicBatchManager( tp_engine=infer_engine, - max_total_token_num=42, - batch_max_tokens=42, + max_total_token_num=640, + batch_max_tokens=608, eos_id=0, log_stats=False, log_stats_interval=10, @@ -56,8 +56,6 @@ def run(): # test abort function dynamic_batch_manager.abort(req4.request_id) assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True - - print(dynamic_batch_manager.req_queue.waiting_req_list) # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested batch = dynamic_batch_manager.req_queue.generate_new_batch()
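For quick reference, the following is a minimal offline usage sketch pieced together from the tests added in this series (test_offline_dynamic_batching.py and test_dynamic_batching_manager.py); it is not part of the patches themselves. The helper names ManagerArgs, run_offline_demo and worker are illustrative, the tiny random-weight LlamaConfig stands in for a real checkpoint, and the sizes are chosen only to satisfy the max_total_token_num / batch_max_tokens assertions introduced above.

from dataclasses import dataclass

from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import spawn


@dataclass
class ManagerArgs:
    # mirrors the fields DynamicBatchManager reads from `args` in start_dynamic_batching
    max_total_token_num: int
    batch_max_tokens: int
    model: str
    eos_id: int
    disable_log_stats: bool
    log_stats_interval: int


def run_offline_demo():
    sampling_params = SamplingParams()
    # pre-registered requests start out in the waiting list, as in the offline test
    waiting_list = [Req(i, [0, 0, 10, 6, 8], sampling_params) for i in range(2)]

    # tiny random-weight model taken from the tests; a real checkpoint would be loaded here instead
    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024)
    model = LlamaForCausalLM(llama_config).half()

    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    # TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
    infer_engine = TPInferEngine(model, shard_config, 2, 48, 256)

    # 640 >= 2 * (48 + 256) and 608 >= 48 + 256, so the new config assertions hold
    args = ManagerArgs(
        max_total_token_num=640,
        batch_max_tokens=608,
        model="llama",
        eos_id=0,
        disable_log_stats=False,
        log_stats_interval=10,
    )
    batch_manager = start_dynamic_batching(args, tp_engine=infer_engine, waiting_req_list=waiting_list)

    # generate() enqueues the new prompt and yields results as requests finish
    for output in batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params):
        print(output)


def worker(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_offline_demo()


if __name__ == "__main__":
    spawn(worker, 1)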