From ba09b5437ff21e0cd6ce69b81af1c4fda1773543 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 25 Sep 2023 12:23:08 +0800 Subject: [PATCH 01/12] finish batch manager --- .../inference/dynamic_batching/__init__.py | 0 .../inference/dynamic_batching/infer_batch.py | 299 ++++++++++++++++++ .../inference/dynamic_batching/io_struct.py | 133 ++++++++ .../inference/dynamic_batching/req_queue.py | 68 ++++ .../dynamic_batching/sampling_params.py | 82 +++++ colossalai/inference/manager.py | 256 +++++++++++++++ 6 files changed, 838 insertions(+) create mode 100644 colossalai/inference/dynamic_batching/__init__.py create mode 100644 colossalai/inference/dynamic_batching/infer_batch.py create mode 100644 colossalai/inference/dynamic_batching/io_struct.py create mode 100644 colossalai/inference/dynamic_batching/req_queue.py create mode 100644 colossalai/inference/dynamic_batching/sampling_params.py create mode 100644 colossalai/inference/manager.py diff --git a/colossalai/inference/dynamic_batching/__init__.py b/colossalai/inference/dynamic_batching/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py new file mode 100644 index 000000000000..011c4a4221ef --- /dev/null +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -0,0 +1,299 @@ +import torch +import numpy as np +import collections + +from lightllm.common.configs.config import setting +from dataclasses import dataclass +from typing import List, Dict +from lightllm.common.mem_manager import MemoryManager +from lightllm.utils.infer_utils import mark_start, mark_end + + +class InferSamplingParams: + + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + vocab_size: int = -1, + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + if self.top_k == -1: + self.top_k = vocab_size + return + + +@dataclass +class InferBatch: + batch_id: int + requests: List + requests_idx_mapping: Dict[int, int] + + input_ids: torch.Tensor + + all_input_ids: List[List[int]] + input_lengths: List[int] + + out_token_id_counts: List + sampling_param_list : List[InferSamplingParams] + + input_ids: torch.Tensor + + nopad_total_token_num: int + nopad_max_len_in_batch: int + nopad_b_loc: torch.Tensor + nopad_b_start_loc: torch.Tensor + nopad_b_seq_len: torch.Tensor + mem_manager: MemoryManager + + @classmethod + @torch.no_grad() + def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device, mem_manager:MemoryManager, vocab_size: int): + + input_lengths = [] + all_input_ids = [] + requests_idx_mapping = {} + + out_token_id_counts = [] + sampling_param_list = [] + + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(requests), setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') + nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device='cuda') + for i, r in enumerate(requests): + # request id -> idx in list mapping + requests_idx_mapping[r['request_id']] = i + + tokenized_input = r['input_id'] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + out_token_id_counts.append(collections.defaultdict(int)) + + # 
postprocessor + sampling_param = r["sampling_param"] + sampling_param["vocab_size"] = vocab_size + sampling_param_list.append(InferSamplingParams(**sampling_param)) + + nopad_total_token_num += input_length + nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) + + + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + if len(requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + else: + input_ids = all_input_ids[0] + + # Create tensors on device + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + return cls( + batch_id=batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + mem_manager=mem_manager, + ) + + @torch.no_grad() + def free_self(self): + remove_index = [] + for idx in range(len(self)): + remove_index.append(self.nopad_b_loc[idx, (self.nopad_max_len_in_batch - 1) - (self.nopad_b_seq_len[idx] - 1): (self.nopad_max_len_in_batch - 1)]) + remove_index = torch.cat(remove_index, dim=-1) + self.mem_manager.free(remove_index) + return + + @torch.no_grad() + def filter(self, request_ids: List[int]): + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + requests_idx_mapping = {} + indices = [] + requests = [] + all_input_ids = [] + input_lengths = [] + + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(request_ids), setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') + nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda') + nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda') + + left_idx = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + left_idx.append(idx) + + left_idx_set = set(left_idx) + remove_index = [] + for idx in range(len(self)): + if idx not in left_idx_set: + remove_index.append(self.nopad_b_loc[idx, (self.nopad_max_len_in_batch - 1) - (self.nopad_b_seq_len[idx] - 1): (self.nopad_max_len_in_batch - 1)]) + remove_index = torch.cat(remove_index, dim=-1) + + self.mem_manager.free(remove_index) + + nopad_max_len_in_batch = 0 + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + + nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] + nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + nopad_total_token_num = torch.sum(nopad_b_seq_len).item() + + nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[indices, (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1): (self.nopad_max_len_in_batch - 1)] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + requests.append(self.requests[idx]) + all_input_ids.append(self.all_input_ids[idx]) + input_lengths.append(self.input_lengths[idx]) + + input_ids = self.input_ids[indices] + + 
return InferBatch( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], + sampling_param_list=[self.sampling_param_list[_i] for _i in indices], + mem_manager=self.mem_manager + ) + + @classmethod + @torch.no_grad() + def merge(cls, batch1, batch2): + requests = batch1.requests + batch2.requests + requests_idx_mapping = {} + new_batch_size = len(batch1) + len(batch2) + + input_ids = batch1.input_ids.new_empty(new_batch_size) + all_input_ids = [] + input_lengths = [] + out_token_id_counts=[] + sampling_param_list=[] + + cumulative_batch_size = 0 + nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num + nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2 .nopad_max_len_in_batch) + + nopad_b_loc = torch.empty((new_batch_size, setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') + nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda') + nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda') + nopad_start_loc_len_temp = 0 + batches = [batch1, batch2] + for i, batch in enumerate(batches): + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + cumulative_batch_size + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + input_ids[start_index:end_index] = batch.input_ids + nopad_b_seq_len[start_index: end_index] = batch.nopad_b_seq_len + nopad_b_start_loc[start_index: end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp + nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] + nopad_b_loc[start_index: end_index, nopad_max_len_in_batch - batch.nopad_max_len_in_batch: nopad_max_len_in_batch - + 1] = batch.nopad_b_loc[:, :batch.nopad_max_len_in_batch - 1] + + all_input_ids.extend(batch.all_input_ids) + + input_lengths.extend(batch.input_lengths) + out_token_id_counts.extend(batch.out_token_id_counts) + sampling_param_list.extend(batch.sampling_param_list) + # Update + cumulative_batch_size += len(batch) + + nopad_b_loc[:, nopad_max_len_in_batch - 1] = nopad_total_token_num - \ + new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device='cuda') + return InferBatch( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + mem_manager=batches[0].mem_manager + ) + + def __len__(self): + return len(self.requests) + + + def get_post_sample_tensors(self): + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + top_ks: List[int] = [] + p_token_ids: List[int] = [] + p_token_counts: List[int] = [] + p_seq_len: List[int] = [0,] + 
p_max_len_in_batch: int = 0 + for i, id_to_count in enumerate(self.out_token_id_counts): + sample_param = self.sampling_param_list[i] + presence_penalties.append(sample_param.presence_penalty) + frequency_penalties.append(sample_param.frequency_penalty) + temperatures.append(sample_param.temperature) + top_ps.append(sample_param.top_p) + top_ks.append(sample_param.top_k) + + for token_id, count in id_to_count.items(): + p_token_ids.append(token_id) + p_token_counts.append(count) + p_seq_len.append(len(id_to_count)) + p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) + + presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") + frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") + temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") + top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda") + top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda") + p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda") + p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") + p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") + p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32) + return presence_penalties, frequency_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch \ No newline at end of file diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py new file mode 100644 index 000000000000..5324ee262986 --- /dev/null +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -0,0 +1,133 @@ +from .sampling_params import SamplingParams +from typing import Dict, List, Optional, Tuple +import asyncio + + +class Req: + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + self.request_id = request_id + self.prompt_ids = prompt_ids + self.input_len = len(prompt_ids) + self.max_output_len = sample_params.max_new_tokens + self.sample_params = sample_params + self.output_ids = [] + self.output_metadata_list = [] + self.has_generate_finished = False + self.aborted = False + + def to_rpc_obj(self): + return {"request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict() } + + def to_req_detokenization_state(self): + out = ReqDetokenizationState(self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos) + if self.output_metadata_list: + out.gen_metadata.update(self.output_metadata_list[-1]) + return out + + def stop_sequences_matched(self): + for stop_token_ids in self.sample_params.stop_sequences: + stop_len = len(stop_token_ids) + if stop_len > 0: + if len(self.output_ids) >= stop_len: + if all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + return True + return False + + def __repr__(self): + return (f"request_id(n={self.request_id}, " + f"prompt_ids={self.prompt_ids}, ") + + +class ReqDetokenizationState: + def __init__( + self, + request_id: str, + prompt_ids: List[int], + max_output_len: int, + ignore_eos: bool, + ) -> None: + self.request_id = request_id + self.prompt_ids = prompt_ids + self.output_ids = [] + self.output_tokens = [] + self.output_str = "" + self.sub_texts = [] + self.current_sub_text = [] + self.max_output_len = max_output_len + self.ignore_eos = ignore_eos + self.gen_metadata = {} + +class Batch: + def 
__init__(self, batch_id, reqs: List[Req]): + self.batch_id = batch_id + self.reqs = reqs + self.id_to_reqs = {req.request_id: req for req in reqs} + + def input_tokens(self): + batch_input_tokens = 0 + for req in self.reqs: + batch_input_tokens += req.input_len + return batch_input_tokens + + def calcu_max_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + req.max_output_len + return tokens + + def calcu_used_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + len(req.output_ids) + return tokens + + def mark_finished_req(self, eos_id): + has_new_finish = False + for req in self.reqs: + if req.stop_sequences_matched(): + req.has_generate_finished = True + has_new_finish = True + if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= req.max_output_len or req.aborted: + req.has_generate_finished = True + has_new_finish = True + return has_new_finish + + def filter_finished(self): + unfinished_req = [] + for req in self.reqs: + if not req.has_generate_finished: + unfinished_req.append(req) + self.reqs = unfinished_req + self.id_to_reqs = {req.request_id: req for req in self.reqs} + + def is_clear(self): + return len(self.reqs) == 0 + + def merge(self, mini_batch): + for _req in mini_batch.reqs: + self.reqs.append(_req) + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return + + def __repr__(self): + return (f"batch_id={self.batch_id}, " + f"reqs={self.reqs}, ") + +class BatchTokenIdOut: + def __init__(self): + self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + +class BatchStrOut: + def __init__(self): + self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + +class AbortReq: + def __init__(self, req_id): + self.req_id = req_id + diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py new file mode 100644 index 000000000000..1214140d2841 --- /dev/null +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -0,0 +1,68 @@ +import uuid +import asyncio +import numpy as np +from typing import List +from ..io_struct import Batch, Req +from lightllm.utils.infer_utils import calculate_time + + +class ReqQueue: + + def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size) -> None: + self.max_total_tokens = max_total_tokens + assert batch_max_tokens is not None + self.batch_max_tokens = batch_max_tokens + self.running_max_req_size = running_max_req_size + self.waiting_req_list: List[Req] = [] + + def append(self, req): + self.waiting_req_list.append(req) + return + + def _init_cache_list(self, current_batch:Batch): + if current_batch is not None: + self.cache_len_list = [(req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) for req in current_batch.reqs] + else: + self.cache_len_list = [] + + # @calculate_time(show=True, min_cost_ms=0.1) + def _can_add_new_req(self, req): + self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis + self.cache_len_list.sort(key=lambda x: -x[1]) + + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + # assert left_out_len_array.min() >= 0 + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, 
len(self.cache_len_list) + 1, 1) + + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + if need_max_token_num < self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size: + return True + else: + return False + + def generate_new_batch(self, current_batch:Batch): + if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: + return None + + self._init_cache_list(current_batch) + can_run_list = [] + new_batch_total_tokens = 0 + aborted_count = 0 + for req in self.waiting_req_list: + if req.aborted: + aborted_count += 1 + continue + if self._can_add_new_req(req) and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: + can_run_list.append(req) + new_batch_total_tokens += req.input_len + else: + break + + if len(can_run_list) != 0: + new_batch = Batch(uuid.uuid4().hex, can_run_list) + self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count:] + return new_batch + else: + return None diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py new file mode 100644 index 000000000000..8af532dfa39c --- /dev/null +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -0,0 +1,82 @@ +"""Sampling parameters for text generation.""" +from typing import List, Optional, Union + +_SAMPLING_EPS = 1e-5 + + +class SamplingParams: + + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, # -1 is for all + ignore_eos: bool = False, + max_new_tokens: int = 16, + stop_sequences: Optional[Union[str, List[str]]] = None # 停止句子条件 + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.ignore_eos = ignore_eos + self.max_new_tokens = max_new_tokens + self.stop_sequences = stop_sequences + if self.do_sample == False: + self.temperature = 1.0 + self.top_p = 1.0 + self.top_k = 1 + if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search + self.temperature = 1.0 + self.top_k = 1 + return + + def verify(self): + if self.presence_penalty < 0.0: + raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") + if self.frequency_penalty < 0.0: + raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}") + if self.temperature <= 0.0: + raise ValueError(f"temperature must > 0.0, got {self.temperature}") + if self.top_p <= 0.0 or self.top_p > 1.0: + raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.") + if self.max_new_tokens < 1: + raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.") + return + + def stop_sentences_to_token_ids(self, tokenizer): + if self.stop_sequences is None: + self.stop_sequences = [] + else: + if isinstance(self.stop_sequences, str): + self.stop_sequences = [self.stop_sequences] + new_stop_sequences = [] + for stop_str in self.stop_sequences: + stop_str_ids = tokenizer.encode(stop_str) + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + stop_str_ids = stop_str_ids[1:] + if len(stop_str_ids) > 0: + new_stop_sequences.append(stop_str_ids) + 
self.stop_sequences = new_stop_sequences + return + + def to_dict(self): + ret = {} + ret["do_sample"] = self.do_sample + ret["presence_penalty"] = self.presence_penalty + ret["frequency_penalty"] = self.frequency_penalty + ret["temperature"] = self.temperature + ret["top_p"] = self.top_p + ret["top_k"] = self.top_k + # if self.ignore_eos is not None: + # ret["ignore_eos"] = self.ignore_eos + # if self.max_tokens is not None: + # ret["max_tokens"] = self.max_tokens + return ret diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py new file mode 100644 index 000000000000..09dfae0632f6 --- /dev/null +++ b/colossalai/inference/manager.py @@ -0,0 +1,256 @@ +import time +import uvloop +import asyncio +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +import zmq +import zmq.asyncio +from typing import Dict, List, Optional +from dynamic_batching.infer_batch import InferBatch +from ..sampling_params import SamplingParams +from inference.dynamic_batching import Req, Batch +from .model_infer.model_rpc import start_model_process, ModelRpcClient +from dynamic_batching.req_queue import ReqQueue +from lightllm.utils.infer_utils import calculate_time +from dynamic_batching.io_struct import BatchTokenIdOut, AbortReq +from .stats import Stats + +class DynamicBatchManager: + + def __init__(self,tp_engine, world_size, max_total_token_num, batch_max_tokens, running_max_req_size, eos_id, + router_port, detokenization_port, model_rpc_ports, log_stats=True, log_stats_interval=10): + self.engine = tp_engine + self.world_size = world_size + self.max_total_token_num = max_total_token_num + + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size) + # all the inputs should be put into req_queue + + self.running_batch: Batch = None + self.eos_id = eos_id + self.has_wait_tokens = 0 + self.max_wait_tokens = 10 + + context = zmq.asyncio.Context(2) + + # self.send_to_detokenization = context.socket(zmq.PUSH) + # self.send_to_detokenization.connect(f"tcp://127.0.0.1:{detokenization_port}") + + self.stats_tool = Stats(log_stats, log_stats_interval) + + # In Torch serve, model is initialized before manage + async def wait_to_model_ready(self): + pass + + def add_req( + self, + prompt_ids: List[int], + sampling_params: SamplingParams, + request_id: str + ): + req = Req(request_id, prompt_ids, sampling_params) + self.req_queue.append(req) + self.send_to_detokenization.send_pyobj(req.to_req_detokenization_state()) + return + + async def abort(self, request_id): + if self.running_batch is not None: + for req in self.running_batch.reqs: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + for req in self.req_queue.waiting_req_list: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + return + + async def loop_for_fwd(self,): + counter_count = 0 + while True: + await self._step() + counter_count += 1 + if self.running_batch is not None: + if counter_count % 50 == 0: + print("current batch size:", len(self.running_batch.reqs), "token used ratio:", self.running_batch.calcu_used_tokens() / self.max_total_token_num) + pass + self.stats_tool.print_stats() + + if self.running_batch is None: + await asyncio.sleep(0.01) # 10ms + + async def _step(self): + """ + handle the requests + """ + # 删除所有已经 finished 的 req + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + 
self.running_batch = new_batch + await self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + return + + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + await self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + return + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + await self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + await self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + else: + self.stats_tool.count_output_tokens(self.running_batch) + await self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + return + + async def _init_batch(self, batch: Batch): + reqs = [r.to_rpc_obj() for r in batch.reqs] + #rets = [self.model_rpcs[tp_rank].init_batch(batch.batch_id, reqs) for tp_rank in range(self.world_size)] + if self.world_size != 1: + batch_id, reqs, dtype = obtain(batch_id), obtain(reqs), obtain(dtype) + import torch + if dtype == "fp16": + dtype = torch.float16 + else: + assert False, "error dtype" + batch_data = InferBatch.init_batch(batch_id, reqs, dtype, torch.cuda.current_device(), self.engine.model.mem_manager, self.engine.model.vocab_size) + self.cache[batch_id] = batch_data + return + + async def _prefill_batch(self, batch): + await self._init_batch(batch) + # rets = [self.model_rpcs[tp_rank].foward(batch.batch_id) for tp_rank in range(self.world_size)] + # TODO: figure out if cache and batch id is needed + rets = self.engine.prefill(batch.batch_id) + ans = await asyncio.gather(*rets) + + if self.world_size != 1: + req_to_out_token_id = obtain(ans[0]) + else: + req_to_out_token_id = ans[0] + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + # self._send_to_detokenization_proc(batch, req_to_out_token_id) + await self._handle_finish_req(batch, has_new_finished_req) + return + + async def _decode_batch(self, batch:Batch): + # rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)] + rets = self.engine.decode(batch.batch_id) + ans = await asyncio.gather(*rets) + if self.world_size != 1: + req_to_out_token_id = obtain(ans[0]) # gather or something + else: + req_to_out_token_id = ans[0] + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + #self._send_to_detokenization_proc(batch, req_to_out_token_id) + await self._handle_finish_req(batch, has_new_finished_req) + return + + async def _filter_batch(self, batch: Batch): + req_id_list = [r.request_id for r in batch.reqs] + filter_batch = batch.filter(req_id_list) + batch = filter_batch + # rets = [self.model_rpcs[tp_rank].filter_batch(batch.batch_id, req_id_list) for tp_rank in range(self.world_size)] + # await asyncio.gather(*rets) + return + + async def _merge_batch(self, batch1, batch2): + # rets = [self.model_rpcs[tp_rank].merge_batch(batch1.batch_id, batch2.batch_id) for tp_rank in range(self.world_size)] + # await asyncio.gather(*rets) + m_batch = InferBatch.merge(batch1, batch2) + del batch2 + batch1 = m_batch + return + + async def _remove_batch(self, batch): + batch.free_self() + del batch + # rets = 
[self.model_rpcs[tp_rank].remove_batch(batch.batch_id) for tp_rank in range(self.world_size)] + # await asyncio.gather(*rets) + return + + async def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + batch.filter_finished() + if batch.is_clear(): + await self._remove_batch(batch) + else: + await self._filter_batch(batch) + return + + def _filter_runing_batch(self): + if self.running_batch is not None and self.running_batch.is_clear(): + self.running_batch = None + return + + def _add_token_id_to_req(self, batch: Batch, req_ans): + for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): + req = batch.id_to_reqs[req_id] + req.output_ids.append(new_token_id) + req.output_metadata_list.append(new_gen_metadata) + return + + async def loop_for_netio_req(self): + while True: + recv_req = await self.recv_from_httpserver.recv_pyobj() + if isinstance(recv_req, tuple) and len(recv_req) == 3: + prompt_ids, sampling_params, request_id = recv_req + self.add_req(prompt_ids, sampling_params, request_id) + elif isinstance(recv_req, AbortReq): + abort_req = recv_req + request_id = abort_req.req_id + await self.abort(request_id) + self.send_to_detokenization.send_pyobj(abort_req) + else: + assert False, f"Error Req Inf {recv_req}" + + def clean_up(self): + #this logic should be implemented + pass + +def start_router_process(args, router_port, detokenization_port, model_rpc_ports, mode, pipe_writer): + try: + batch_manager = DynamicBatchManager( + args.model_dir, + world_size=args.tp, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + running_max_req_size=args.running_max_req_size, + eos_id=args.eos_id, + router_port=router_port, + detokenization_port=detokenization_port, + model_rpc_ports=model_rpc_ports, + mode=mode, + log_stats = not args.disable_log_stats, + log_stats_interval = args.log_stats_interval) + + asyncio.run(batch_manager.wait_to_model_ready()) + except Exception as e: + import traceback + err_str = '\n'.join(traceback.format_exception(e)) + pipe_writer.send(err_str) + batch_manager.clean_up() + raise + + pipe_writer.send('init ok') + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.create_task(batch_manager.loop_for_fwd()) + loop.run_until_complete(batch_manager.loop_for_netio_req()) + return From 58f24c8c16b02789960e91dbb6a5e18898a2e407 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 26 Sep 2023 11:54:58 +0800 Subject: [PATCH 02/12] 1 --- .../inference/dynamic_batching/infer_batch.py | 137 ++++++++----- .../inference/dynamic_batching/req_queue.py | 34 ++-- colossalai/inference/dynamic_batching/stas.py | 43 ++++ colossalai/inference/manager.py | 192 +++++++++++------- .../inference/tensor_parallel/engine.py | 82 ++++++-- .../tensor_parallel/modeling/__init__.py | 2 - colossalai/kernel/triton/__init__.py | 52 ++--- tests/kit/model_zoo/torchrec/__init__.py | 2 +- 8 files changed, 361 insertions(+), 183 deletions(-) create mode 100644 colossalai/inference/dynamic_batching/stas.py diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index 011c4a4221ef..b6781e0347b7 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -1,16 +1,14 @@ -import torch -import numpy as np import collections +from dataclasses import dataclass +from typing import Dict, List +import numpy as np +import torch from lightllm.common.configs.config import setting -from 
dataclasses import dataclass -from typing import List, Dict from lightllm.common.mem_manager import MemoryManager -from lightllm.utils.infer_utils import mark_start, mark_end class InferSamplingParams: - def __init__( self, do_sample: bool = False, @@ -42,9 +40,9 @@ class InferBatch: all_input_ids: List[List[int]] input_lengths: List[int] - + out_token_id_counts: List - sampling_param_list : List[InferSamplingParams] + sampling_param_list: List[InferSamplingParams] input_ids: torch.Tensor @@ -57,24 +55,25 @@ class InferBatch: @classmethod @torch.no_grad() - def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device, mem_manager:MemoryManager, vocab_size: int): - + def init_batch( + cls, batch_id, requests, dtype: torch.dtype, device: torch.device, mem_manager: MemoryManager, vocab_size: int + ): input_lengths = [] all_input_ids = [] requests_idx_mapping = {} - + out_token_id_counts = [] sampling_param_list = [] - + nopad_total_token_num = 0 nopad_max_len_in_batch = 0 - nopad_b_loc = torch.empty((len(requests), setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') - nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device='cuda') + nopad_b_loc = torch.empty((len(requests), setting["max_req_total_len"] + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") for i, r in enumerate(requests): # request id -> idx in list mapping - requests_idx_mapping[r['request_id']] = i + requests_idx_mapping[r["request_id"]] = i - tokenized_input = r['input_id'] + tokenized_input = r["input_id"] input_length = len(tokenized_input) input_lengths.append(input_length) @@ -85,11 +84,10 @@ def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device sampling_param = r["sampling_param"] sampling_param["vocab_size"] = vocab_size sampling_param_list.append(InferSamplingParams(**sampling_param)) - + nopad_total_token_num += input_length nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) - - + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] if len(requests) > 1: @@ -116,16 +114,22 @@ def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device sampling_param_list=sampling_param_list, mem_manager=mem_manager, ) - + @torch.no_grad() def free_self(self): remove_index = [] for idx in range(len(self)): - remove_index.append(self.nopad_b_loc[idx, (self.nopad_max_len_in_batch - 1) - (self.nopad_b_seq_len[idx] - 1): (self.nopad_max_len_in_batch - 1)]) + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) remove_index = torch.cat(remove_index, dim=-1) self.mem_manager.free(remove_index) return - + @torch.no_grad() def filter(self, request_ids: List[int]): if len(request_ids) == 0: @@ -140,42 +144,53 @@ def filter(self, request_ids: List[int]): nopad_total_token_num = 0 nopad_max_len_in_batch = 0 - nopad_b_loc = torch.empty((len(request_ids), setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') - nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda') - nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device='cuda') + nopad_b_loc = torch.empty( + (len(request_ids), setting["max_req_total_len"] + 12), dtype=torch.long, device="cuda" + ) + nopad_b_start_loc = 
torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") left_idx = [] for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] left_idx.append(idx) - + left_idx_set = set(left_idx) remove_index = [] for idx in range(len(self)): if idx not in left_idx_set: - remove_index.append(self.nopad_b_loc[idx, (self.nopad_max_len_in_batch - 1) - (self.nopad_b_seq_len[idx] - 1): (self.nopad_max_len_in_batch - 1)]) + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) remove_index = torch.cat(remove_index, dim=-1) - + self.mem_manager.free(remove_index) nopad_max_len_in_batch = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) - + nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] nopad_total_token_num = torch.sum(nopad_b_seq_len).item() - - nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[indices, (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1): (self.nopad_max_len_in_batch - 1)] + + nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[ + indices, + (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1), + ] for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] requests_idx_mapping[request_id] = i requests.append(self.requests[idx]) all_input_ids.append(self.all_input_ids[idx]) input_lengths.append(self.input_lengths[idx]) - + input_ids = self.input_ids[indices] return InferBatch( @@ -192,7 +207,7 @@ def filter(self, request_ids: List[int]): nopad_b_seq_len=nopad_b_seq_len, out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], sampling_param_list=[self.sampling_param_list[_i] for _i in indices], - mem_manager=self.mem_manager + mem_manager=self.mem_manager, ) @classmethod @@ -205,16 +220,16 @@ def merge(cls, batch1, batch2): input_ids = batch1.input_ids.new_empty(new_batch_size) all_input_ids = [] input_lengths = [] - out_token_id_counts=[] - sampling_param_list=[] + out_token_id_counts = [] + sampling_param_list = [] cumulative_batch_size = 0 nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num - nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2 .nopad_max_len_in_batch) - - nopad_b_loc = torch.empty((new_batch_size, setting['max_req_total_len'] + 12), dtype=torch.long, device='cuda') - nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda') - nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device='cuda') + nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) + + nopad_b_loc = torch.empty((new_batch_size, setting["max_req_total_len"] + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") nopad_start_loc_len_temp = 0 batches = [batch1, batch2] for i, batch in enumerate(batches): @@ -226,11 +241,13 @@ def merge(cls, batch1, batch2): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) input_ids[start_index:end_index] = 
batch.input_ids - nopad_b_seq_len[start_index: end_index] = batch.nopad_b_seq_len - nopad_b_start_loc[start_index: end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp + nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len + nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] - nopad_b_loc[start_index: end_index, nopad_max_len_in_batch - batch.nopad_max_len_in_batch: nopad_max_len_in_batch - - 1] = batch.nopad_b_loc[:, :batch.nopad_max_len_in_batch - 1] + nopad_b_loc[ + start_index:end_index, + nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1, + ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1] all_input_ids.extend(batch.all_input_ids) @@ -239,9 +256,10 @@ def merge(cls, batch1, batch2): sampling_param_list.extend(batch.sampling_param_list) # Update cumulative_batch_size += len(batch) - - nopad_b_loc[:, nopad_max_len_in_batch - 1] = nopad_total_token_num - \ - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device='cuda') + + nopad_b_loc[:, nopad_max_len_in_batch - 1] = ( + nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda") + ) return InferBatch( batch_id=batches[0].batch_id, requests=requests, @@ -256,13 +274,12 @@ def merge(cls, batch1, batch2): nopad_b_seq_len=nopad_b_seq_len, out_token_id_counts=out_token_id_counts, sampling_param_list=sampling_param_list, - mem_manager=batches[0].mem_manager + mem_manager=batches[0].mem_manager, ) def __len__(self): return len(self.requests) - - + def get_post_sample_tensors(self): presence_penalties: List[float] = [] frequency_penalties: List[float] = [] @@ -271,7 +288,9 @@ def get_post_sample_tensors(self): top_ks: List[int] = [] p_token_ids: List[int] = [] p_token_counts: List[int] = [] - p_seq_len: List[int] = [0,] + p_seq_len: List[int] = [ + 0, + ] p_max_len_in_batch: int = 0 for i, id_to_count in enumerate(self.out_token_id_counts): sample_param = self.sampling_param_list[i] @@ -280,13 +299,13 @@ def get_post_sample_tensors(self): temperatures.append(sample_param.temperature) top_ps.append(sample_param.top_p) top_ks.append(sample_param.top_k) - + for token_id, count in id_to_count.items(): p_token_ids.append(token_id) p_token_counts.append(count) p_seq_len.append(len(id_to_count)) p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) - + presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") @@ -296,4 +315,14 @@ def get_post_sample_tensors(self): p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32) - return presence_penalties, frequency_penalties, temperatures, top_ps, top_ks, p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch \ No newline at end of file + return ( + presence_penalties, + frequency_penalties, + temperatures, + top_ps, + top_ks, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + p_max_len_in_batch, + ) diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py index 1214140d2841..5982b3443838 100644 --- 
a/colossalai/inference/dynamic_batching/req_queue.py +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -1,51 +1,53 @@ import uuid -import asyncio -import numpy as np from typing import List -from ..io_struct import Batch, Req -from lightllm.utils.infer_utils import calculate_time +import numpy as np -class ReqQueue: +from .io_struct import Batch, Req + +class ReqQueue: def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size) -> None: self.max_total_tokens = max_total_tokens assert batch_max_tokens is not None self.batch_max_tokens = batch_max_tokens self.running_max_req_size = running_max_req_size self.waiting_req_list: List[Req] = [] - + def append(self, req): self.waiting_req_list.append(req) return - - def _init_cache_list(self, current_batch:Batch): + + def _init_cache_list(self, current_batch: Batch): if current_batch is not None: - self.cache_len_list = [(req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) for req in current_batch.reqs] + self.cache_len_list = [ + (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) + for req in current_batch.reqs + ] else: self.cache_len_list = [] - + # @calculate_time(show=True, min_cost_ms=0.1) def _can_add_new_req(self, req): - self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis + self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis self.cache_len_list.sort(key=lambda x: -x[1]) - + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) # assert left_out_len_array.min() >= 0 has_run_len_array = np.array([e[0] for e in self.cache_len_list]) cum_run_len_array = np.cumsum(has_run_len_array) size_array = np.arange(1, len(self.cache_len_list) + 1, 1) - + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() if need_max_token_num < self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size: return True else: return False - def generate_new_batch(self, current_batch:Batch): + def generate_new_batch(self, current_batch: Batch): if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: return None - + self._init_cache_list(current_batch) can_run_list = [] new_batch_total_tokens = 0 @@ -62,7 +64,7 @@ def generate_new_batch(self, current_batch:Batch): if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().hex, can_run_list) - self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count:] + self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch else: return None diff --git a/colossalai/inference/dynamic_batching/stas.py b/colossalai/inference/dynamic_batching/stas.py new file mode 100644 index 000000000000..6d34183f47c4 --- /dev/null +++ b/colossalai/inference/dynamic_batching/stas.py @@ -0,0 +1,43 @@ +import time + + +class Stats: + def __init__(self, log_status, log_stats_interval) -> None: + self.log_stats = log_status + self.log_stats_interval = log_stats_interval + self.last_log_time = time.time() + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + return + + def count_prompt_tokens(self, run_batch): + if self.log_stats: + tokens = run_batch.input_tokens() + self.prompt_tokens += tokens + self.all_tokens += tokens + return + + def count_output_tokens(self, run_batch): + if self.log_stats: + tokens = len(run_batch.reqs) + self.output_tokens += tokens + self.all_tokens += tokens + return + + def print_stats(self): + 
if not self.log_stats: + return + + now = time.time() + if now - self.last_log_time > self.log_stats_interval: + print( + f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s" + ) + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + self.last_log_time = now + return diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 09dfae0632f6..e752859929a5 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,37 +1,52 @@ -import time -import uvloop +import argparse import asyncio + +import uvloop + +import colossalai + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -import zmq -import zmq.asyncio -from typing import Dict, List, Optional +from typing import List + from dynamic_batching.infer_batch import InferBatch -from ..sampling_params import SamplingParams -from inference.dynamic_batching import Req, Batch -from .model_infer.model_rpc import start_model_process, ModelRpcClient +from dynamic_batching.io_struct import Batch, Req from dynamic_batching.req_queue import ReqQueue -from lightllm.utils.infer_utils import calculate_time -from dynamic_batching.io_struct import BatchTokenIdOut, AbortReq -from .stats import Stats +from dynamic_batching.sampling_params import SamplingParams +from dynamic_batching.stas import Stats +from rpyc.utils.classic import obtain +from tensor_parallel.engine import TPInferEngine -class DynamicBatchManager: +from colossalai.shardformer import ShardConfig +from tests.kit.model_zoo import model_zoo +from tests.test_infer.test_llama_infer import init_to_get_rotary - def __init__(self,tp_engine, world_size, max_total_token_num, batch_max_tokens, running_max_req_size, eos_id, - router_port, detokenization_port, model_rpc_ports, log_stats=True, log_stats_interval=10): + +class DynamicBatchManager: + def __init__( + self, + tp_engine: TPInferEngine, + world_size, + max_total_token_num, + batch_max_tokens, + running_max_req_size, + eos_id, + router_port, + log_stats=True, + log_stats_interval=10, + ): self.engine = tp_engine self.world_size = world_size self.max_total_token_num = max_total_token_num - self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size) - # all the inputs should be put into req_queue + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size) + # all the inputs should be put into req_queue self.running_batch: Batch = None self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - - context = zmq.asyncio.Context(2) - + + # context = zmq.asyncio.Context(2) # self.send_to_detokenization = context.socket(zmq.PUSH) # self.send_to_detokenization.connect(f"tcp://127.0.0.1:{detokenization_port}") @@ -41,15 +56,9 @@ def __init__(self,tp_engine, world_size, max_total_token_num, batch_max_tokens, async def wait_to_model_ready(self): pass - def add_req( - self, - prompt_ids: List[int], - sampling_params: SamplingParams, - request_id: str - ): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): req = Req(request_id, prompt_ids, sampling_params) self.req_queue.append(req) - self.send_to_detokenization.send_pyobj(req.to_req_detokenization_state()) return async def abort(self, request_id): @@ -64,17 +73,23 @@ async def 
abort(self, request_id): req.aborted = True return - async def loop_for_fwd(self,): + async def loop_for_fwd(self): + print("why here") counter_count = 0 while True: + print("112221121212") await self._step() counter_count += 1 if self.running_batch is not None: if counter_count % 50 == 0: - print("current batch size:", len(self.running_batch.reqs), "token used ratio:", self.running_batch.calcu_used_tokens() / self.max_total_token_num) - pass + print( + "current batch size:", + len(self.running_batch.reqs), + "token used ratio:", + self.running_batch.calcu_used_tokens() / self.max_total_token_num, + ) self.stats_tool.print_stats() - + if self.running_batch is None: await asyncio.sleep(0.01) # 10ms @@ -113,28 +128,37 @@ async def _step(self): await self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 - + return async def _init_batch(self, batch: Batch): reqs = [r.to_rpc_obj() for r in batch.reqs] - #rets = [self.model_rpcs[tp_rank].init_batch(batch.batch_id, reqs) for tp_rank in range(self.world_size)] + # rets = [self.model_rpcs[tp_rank].init_batch(batch.batch_id, reqs) for tp_rank in range(self.world_size)] if self.world_size != 1: batch_id, reqs, dtype = obtain(batch_id), obtain(reqs), obtain(dtype) import torch + if dtype == "fp16": dtype = torch.float16 else: assert False, "error dtype" - batch_data = InferBatch.init_batch(batch_id, reqs, dtype, torch.cuda.current_device(), self.engine.model.mem_manager, self.engine.model.vocab_size) - self.cache[batch_id] = batch_data + # cache may be removed + batch_data = InferBatch.init_batch( + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.engine.model.mem_manager, + self.engine.model.vocab_size, + ) + self.engine.cache[batch_id] = batch_data return async def _prefill_batch(self, batch): await self._init_batch(batch) # rets = [self.model_rpcs[tp_rank].foward(batch.batch_id) for tp_rank in range(self.world_size)] # TODO: figure out if cache and batch id is needed - rets = self.engine.prefill(batch.batch_id) + rets = self.engine._prefill_batch(batch.batch_id) ans = await asyncio.gather(*rets) if self.world_size != 1: @@ -147,37 +171,47 @@ async def _prefill_batch(self, batch): await self._handle_finish_req(batch, has_new_finished_req) return - async def _decode_batch(self, batch:Batch): + async def _decode_batch(self, batch: Batch): # rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)] - rets = self.engine.decode(batch.batch_id) + rets = self.engine._decode_batch(batch.batch_id) ans = await asyncio.gather(*rets) if self.world_size != 1: - req_to_out_token_id = obtain(ans[0]) # gather or something + req_to_out_token_id = obtain(ans[0]) # gather or something else: req_to_out_token_id = ans[0] self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - #self._send_to_detokenization_proc(batch, req_to_out_token_id) + # self._send_to_detokenization_proc(batch, req_to_out_token_id) await self._handle_finish_req(batch, has_new_finished_req) return async def _filter_batch(self, batch: Batch): + batch_id = batch.batch_id req_id_list = [r.request_id for r in batch.reqs] + if self.world_size != 1: + batch_id, req_id_list = obtain(batch_id), obtain(req_id_list) + batch = self.engine.cache.pop(batch_id) filter_batch = batch.filter(req_id_list) - batch = filter_batch + del batch + self.engine.cache[batch_id] = filter_batch # rets = [self.model_rpcs[tp_rank].filter_batch(batch.batch_id, req_id_list) for 
tp_rank in range(self.world_size)] # await asyncio.gather(*rets) return async def _merge_batch(self, batch1, batch2): + batch1 = self.engine.cache.pop(batch1.batch_id) + batch2 = self.engine.cache.pop(batch2.batch_id) # rets = [self.model_rpcs[tp_rank].merge_batch(batch1.batch_id, batch2.batch_id) for tp_rank in range(self.world_size)] # await asyncio.gather(*rets) + m_batch = InferBatch.merge(batch1, batch2) + self.engine.cache[batch1.batch_id] = m_batch + del batch1 del batch2 - batch1 = m_batch return async def _remove_batch(self, batch): + batch = self.engine.cache.pop(batch.batch_id) batch.free_self() del batch # rets = [self.model_rpcs[tp_rank].remove_batch(batch.batch_id) for tp_rank in range(self.world_size)] @@ -197,60 +231,74 @@ def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): self.running_batch = None return - + def _add_token_id_to_req(self, batch: Batch, req_ans): for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): req = batch.id_to_reqs[req_id] req.output_ids.append(new_token_id) req.output_metadata_list.append(new_gen_metadata) return - - async def loop_for_netio_req(self): - while True: - recv_req = await self.recv_from_httpserver.recv_pyobj() - if isinstance(recv_req, tuple) and len(recv_req) == 3: - prompt_ids, sampling_params, request_id = recv_req - self.add_req(prompt_ids, sampling_params, request_id) - elif isinstance(recv_req, AbortReq): - abort_req = recv_req - request_id = abort_req.req_id - await self.abort(request_id) - self.send_to_detokenization.send_pyobj(abort_req) - else: - assert False, f"Error Req Inf {recv_req}" def clean_up(self): - #this logic should be implemented - pass + # this logic should be implemented + pass -def start_router_process(args, router_port, detokenization_port, model_rpc_ports, mode, pipe_writer): + +def start_router_process(args, tp_engine, router_port): try: batch_manager = DynamicBatchManager( - args.model_dir, + tp_engine=tp_engine, world_size=args.tp, max_total_token_num=args.max_total_token_num, batch_max_tokens=args.batch_max_tokens, running_max_req_size=args.running_max_req_size, eos_id=args.eos_id, router_port=router_port, - detokenization_port=detokenization_port, - model_rpc_ports=model_rpc_ports, - mode=mode, - log_stats = not args.disable_log_stats, - log_stats_interval = args.log_stats_interval) - - asyncio.run(batch_manager.wait_to_model_ready()) + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + ) + except Exception as e: import traceback - err_str = '\n'.join(traceback.format_exception(e)) - pipe_writer.send(err_str) + + err_str = "\n".join(traceback.format_exception(e)) + print(err_str) + # may need use logger batch_manager.clean_up() raise - pipe_writer.send('init ok') - + print("start router process") + # batch_manager.loop_for_fwd() loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.create_task(batch_manager.loop_for_fwd()) - loop.run_until_complete(batch_manager.loop_for_netio_req()) + print("good") return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tp", type=int, help="tp_size", default=1) + parser.add_argument("--max_total_token_num", type=int, default=2048, help="max_total_token_num") + parser.add_argument("-b", "--batch_max_tokens", type=int, default=1024, help="max tokens of one batch") + parser.add_argument("--running_max_req_size", type=int, default=2, help="max request size of running batch ") + parser.add_argument("--eos_id", type=int, 
default=0, help="The end token of a seq") + parser.add_argument("--disable_log_stats", type=bool, default=False) + parser.add_argument("--log_stats_interval", type=int, default=10) + args = parser.parse_args() + + colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=8081, backend="nccl") + + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm") + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): + orig_model = model_fn() + init_to_get_rotary(orig_model.model, base=10000) + orig_model = orig_model.half() + data = data_gen_fn() + + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + BATCH_SIZE = 8 + MAX_INPUT_LEN = 12 + MAX_OUTPUT_LEN = 100 + infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + start_router_process(args=args, tp_engine=infer_engine, router_port=12345) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index d5ef37fee420..c7c111aa4c5d 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,8 +1,8 @@ from typing import Any, Callable, List, Optional, Union import torch -import torch.distributed as dist import torch.nn as nn +from dynamic_batching.infer_batch import InferBatch # may intergrate with batchinfer state from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig from transformers.generation.stopping_criteria import StoppingCriteriaList @@ -90,6 +90,8 @@ def __init__( self.shard_config = shard_config self.model = None + self.cache = {} + # optimize the original model by sharding with ShardFormer self._optimize_model(model=model.to(device)) @@ -116,13 +118,15 @@ def _init_manager(self) -> None: def _post_init_gptq_buffer(self, model: nn.Module) -> None: from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear + HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False for name, submodule in model.named_modules(): @@ -130,8 +134,9 @@ def _post_init_gptq_buffer(self, model: nn.Module) -> None: self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) if self.use_act_order: - self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, - submodule.outfeatures) + self.max_inner_outer_dim = max( + self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures + ) self.bits = submodule.bits if not (HAS_GPTQ_CUDA and self.bits == 4): return @@ -141,15 +146,16 @@ def _post_init_gptq_buffer(self, model: nn.Module) -> None: max_input_len = self.max_input_len # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
- self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), - dtype=torch.float16, - device=torch.cuda.current_device()) - self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), - dtype=torch.float16, - device=torch.cuda.current_device()) - - gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, - self.gptq_temp_dq_buffer) + self.gptq_temp_state_buffer = torch.zeros( + (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() + ) + self.gptq_temp_dq_buffer = torch.zeros( + (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device() + ) + + gptq_cuda.prepare_buffers( + torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer + ) # Using the default from exllama repo here. matmul_recons_thd = 8 matmul_fused_remap = False @@ -367,6 +373,56 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) infer_state.seq_len += 1 + def forward(self, batch_id, is_prefill): + batch: InferBatch = self.cache.pop(batch_id) + kwargs = { + "batch_size": len(batch), + "total_token_num": batch.nopad_total_token_num, + "max_len_in_batch": batch.nopad_max_len_in_batch, + "input_ids": batch.input_ids, + "b_loc": batch.nopad_b_loc, + "b_start_loc": batch.nopad_b_start_loc, + "b_seq_len": batch.nopad_b_seq_len, + "is_prefill": is_prefill, + } + logits = self.model.forward(**kwargs) + next_token_ids, next_token_probs = sample(logits, batch) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + output_dict = {} + new_input_ids = [] + for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( + zip(batch.requests, batch.all_input_ids, next_token_ids, next_token_logprobs) + ): + # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") + all_input_ids.append(int(next_token_id)) + # all_input_ids_tensor = None + new_input_ids.append(next_token_id) + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] += 1 + batch.out_token_id_counts[i][next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + output_dict[r["request_id"]] = (int(next_token_id), metadata) + + batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() + batch.nopad_b_start_loc = batch.nopad_b_start_loc + torch.arange( + 0, len(batch), dtype=torch.int32, device="cuda" + ) + batch.nopad_total_token_num += len(batch) + batch.nopad_max_len_in_batch += 1 + batch.nopad_b_seq_len += 1 + self.cache[batch.batch_id] = batch + return output_dict + + def _prefill_batch(self, batch_id): + return self.forward(batch_id, is_prefill=True) + + def _decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + # might want to create a sequence pool # add a single request/sequence/input text at a time and record its length # In other words, store the actual length of input tokens representing a single input text diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py index 279b54065eed..4662368b17b4 100644 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -1,5 +1,3 @@ -import _utils - 
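The forward(batch_id, is_prefill) path added to the engine above pops an InferBatch out of the engine-side cache, runs one model step, advances the per-request bookkeeping, and puts the batch back so the manager can keep scheduling it. A minimal, torch-free sketch of that round-trip (ToyBatch and ToyEngine are illustrative stand-ins, not names from the codebase):

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyBatch:
    batch_id: int
    input_lengths: List[int]
    total_token_num: int = 0


class ToyEngine:
    def __init__(self) -> None:
        # mirrors the `self.cache = {}` dict added to TPInferEngine: batch_id -> batch
        self.cache: Dict[int, ToyBatch] = {}

    def forward(self, batch_id: int, is_prefill: bool) -> dict:
        batch = self.cache.pop(batch_id)        # take ownership of the batch for this step
        output = {}
        for i in range(len(batch.input_lengths)):
            output[i] = 0                       # pretend token id 0 was sampled for request i
            batch.input_lengths[i] += 1         # each sequence grows by one token
        batch.total_token_num += len(batch.input_lengths)
        self.cache[batch_id] = batch            # hand the updated batch back to the cache
        return output

    def _prefill_batch(self, batch_id: int) -> dict:
        return self.forward(batch_id, is_prefill=True)

    def _decode_batch(self, batch_id: int) -> dict:
        return self.forward(batch_id, is_prefill=False)


engine = ToyEngine()
engine.cache[0] = ToyBatch(batch_id=0, input_lengths=[5, 5], total_token_num=10)
engine._prefill_batch(0)
engine._decode_batch(0)
assert engine.cache[0].input_lengths == [7, 7]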
from .bloom import BloomInferenceForwards from .chatglm2 import ChatGLM2InferenceForwards from .llama import LlamaInferenceForwards diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 87ea9cf6536e..7cb71a1c1061 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -1,29 +1,31 @@ -try: - import triton +# try: - HAS_TRITON = True +import triton - from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd - from .copy_kv_cache_dest import copy_kv_cache_to_dest - from .fused_layernorm import layer_norm - from .gptq_triton import gptq_fused_linear_triton - from .rms_norm import rmsnorm_forward - from .rotary_embedding_kernel import rotary_embedding_fwd - from .softmax import softmax - from .token_attention_kernel import token_attention_fwd +HAS_TRITON = True - __all__ = [ - "llama_context_attn_fwd", - "bloom_context_attn_fwd", - "softmax", - "layer_norm", - "rmsnorm_forward", - "copy_kv_cache_to_dest", - "rotary_embedding_fwd", - "token_attention_fwd", - "gptq_fused_linear_triton", - ] +from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd +from .copy_kv_cache_dest import copy_kv_cache_to_dest +from .fused_layernorm import layer_norm +from .gptq_triton import gptq_fused_linear_triton +from .rms_norm import rmsnorm_forward +from .rotary_embedding_kernel import rotary_embedding_fwd +from .softmax import softmax +from .token_attention_kernel import token_attention_fwd -except ImportError: - HAS_TRITON = False - print("Triton is not installed. Please install Triton to use Triton kernels.") +__all__ = [ + "llama_context_attn_fwd", + "bloom_context_attn_fwd", + "softmax", + "layer_norm", + "rmsnorm_forward", + "copy_kv_cache_to_dest", + "rotary_embedding_fwd", + "token_attention_fwd", + "gptq_fused_linear_triton", +] + +# except ImportError: +# print(ImportError.msg) +# HAS_TRITON = False +# print("Triton is not installed. 
Please install Triton to use Triton kernels.") diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..458d7bc81ce8 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +# from .torchrec import * From 3fa9bf08808fb0d80dbd96135f42f7b9e10df608 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 28 Sep 2023 02:07:22 +0800 Subject: [PATCH 03/12] first --- .../inference/dynamic_batching/infer_batch.py | 21 +-- .../inference/dynamic_batching/req_queue.py | 11 +- colossalai/inference/manager.py | 132 ++++++++---------- .../inference/tensor_parallel/engine.py | 49 +++++-- .../tensor_parallel/modeling/llama.py | 22 ++- .../kernel/triton/copy_kv_cache_dest.py | 1 - tests/kit/model_zoo/transformers/llama.py | 12 +- tests/test_infer/test_chatglm2_infer.py | 2 +- tests/test_infer/test_llama_infer.py | 45 +++--- 9 files changed, 165 insertions(+), 130 deletions(-) diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index b6781e0347b7..4ba3024d0d02 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -5,7 +5,10 @@ import numpy as np import torch from lightllm.common.configs.config import setting -from lightllm.common.mem_manager import MemoryManager + +from colossalai.inference.tensor_parallel import MemoryManager + +# make batch infer state an attr of InferBatch class InferSamplingParams: @@ -44,19 +47,17 @@ class InferBatch: out_token_id_counts: List sampling_param_list: List[InferSamplingParams] - input_ids: torch.Tensor - nopad_total_token_num: int nopad_max_len_in_batch: int nopad_b_loc: torch.Tensor nopad_b_start_loc: torch.Tensor nopad_b_seq_len: torch.Tensor - mem_manager: MemoryManager + cache_manager: MemoryManager @classmethod @torch.no_grad() def init_batch( - cls, batch_id, requests, dtype: torch.dtype, device: torch.device, mem_manager: MemoryManager, vocab_size: int + cls, batch_id, requests, dtype: torch.dtype, device: torch.device, cache_manager: MemoryManager, vocab_size: int ): input_lengths = [] all_input_ids = [] @@ -112,7 +113,7 @@ def init_batch( nopad_b_seq_len=nopad_b_seq_len, out_token_id_counts=out_token_id_counts, sampling_param_list=sampling_param_list, - mem_manager=mem_manager, + cache_manager=cache_manager, ) @torch.no_grad() @@ -127,7 +128,7 @@ def free_self(self): ] ) remove_index = torch.cat(remove_index, dim=-1) - self.mem_manager.free(remove_index) + self.cache_manager.free(remove_index) return @torch.no_grad() @@ -168,7 +169,7 @@ def filter(self, request_ids: List[int]): ) remove_index = torch.cat(remove_index, dim=-1) - self.mem_manager.free(remove_index) + self.cache_manager.free(remove_index) nopad_max_len_in_batch = 0 for i, request_id in enumerate(request_ids): @@ -207,7 +208,7 @@ def filter(self, request_ids: List[int]): nopad_b_seq_len=nopad_b_seq_len, out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], sampling_param_list=[self.sampling_param_list[_i] for _i in indices], - mem_manager=self.mem_manager, + cache_manager=self.cache_manager, ) @classmethod @@ -274,7 +275,7 @@ def merge(cls, batch1, batch2): nopad_b_seq_len=nopad_b_seq_len, out_token_id_counts=out_token_id_counts, sampling_param_list=sampling_param_list, - mem_manager=batches[0].mem_manager, + cache_manager=batches[0].cache_manager, ) def __len__(self): diff --git 
a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py index 5982b3443838..61573486d0d1 100644 --- a/colossalai/inference/dynamic_batching/req_queue.py +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -7,12 +7,12 @@ class ReqQueue: - def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size) -> None: + def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None: self.max_total_tokens = max_total_tokens assert batch_max_tokens is not None self.batch_max_tokens = batch_max_tokens self.running_max_req_size = running_max_req_size - self.waiting_req_list: List[Req] = [] + self.waiting_req_list: List[Req] = waiting_req_list def append(self, req): self.waiting_req_list.append(req) @@ -39,7 +39,8 @@ def _can_add_new_req(self, req): size_array = np.arange(1, len(self.cache_len_list) + 1, 1) need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - if need_max_token_num < self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size: + # NOTE: change here < to <= + if need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size: return True else: return False @@ -47,16 +48,16 @@ def _can_add_new_req(self, req): def generate_new_batch(self, current_batch: Batch): if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: return None - self._init_cache_list(current_batch) can_run_list = [] new_batch_total_tokens = 0 aborted_count = 0 for req in self.waiting_req_list: + flag = self._can_add_new_req(req) if req.aborted: aborted_count += 1 continue - if self._can_add_new_req(req) and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: + if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: can_run_list.append(req) new_batch_total_tokens += req.input_len else: diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index e752859929a5..39d6fa269384 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,11 +1,5 @@ import argparse -import asyncio - -import uvloop - -import colossalai - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +import time from typing import List from dynamic_batching.infer_batch import InferBatch @@ -15,12 +9,14 @@ from dynamic_batching.stas import Stats from rpyc.utils.classic import obtain from tensor_parallel.engine import TPInferEngine +from transformers import LlamaForCausalLM, LlamaTokenizer +import colossalai from colossalai.shardformer import ShardConfig -from tests.kit.model_zoo import model_zoo from tests.test_infer.test_llama_infer import init_to_get_rotary +# faulthandler.enable() class DynamicBatchManager: def __init__( self, @@ -30,18 +26,19 @@ def __init__( batch_max_tokens, running_max_req_size, eos_id, - router_port, log_stats=True, log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list=[], ): self.engine = tp_engine self.world_size = world_size self.max_total_token_num = max_total_token_num - self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size) + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) # all the inputs should be put into req_queue - self.running_batch: Batch = None + self.running_batch: Batch = running_batch self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 @@ -53,7 +50,7 @@ def __init__( 
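The `<` to `<=` change in _can_add_new_req() above keeps a candidate request admissible when its projected peak token demand exactly matches the memory budget. A standalone reproduction of that check on made-up numbers; the layout of cache_len_list as (current_len, remaining_output_len) pairs sorted by remaining output length in descending order is inferred from the surrounding code, so treat it as an assumption:

import numpy as np

max_total_tokens = 42                          # hypothetical memory budget
cache_len_list = [(5, 16), (5, 12), (5, 8)]    # hypothetical (current_len, left_output_len) pairs

left_out_len_array = np.array([e[1] for e in cache_len_list])
cum_run_len_array = np.cumsum(np.array([e[0] for e in cache_len_list]))
size_array = np.arange(1, len(cache_len_list) + 1, 1)

# worst-case token footprint if the request with the most output left finishes last
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
print(need_max_token_num)                      # 39
print(need_max_token_num <= max_total_tokens)  # True, so the token budget allows admission
# the real check additionally caps the number of running requests at running_max_req_size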
self.stats_tool = Stats(log_stats, log_stats_interval) # In Torch serve, model is initialized before manage - async def wait_to_model_ready(self): + def wait_to_model_ready(self): pass def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): @@ -61,7 +58,7 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - async def abort(self, request_id): + def abort(self, request_id): if self.running_batch is not None: for req in self.running_batch.reqs: if req.request_id == request_id: @@ -73,12 +70,10 @@ async def abort(self, request_id): req.aborted = True return - async def loop_for_fwd(self): - print("why here") + def loop_for_fwd(self): counter_count = 0 - while True: - print("112221121212") - await self._step() + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() counter_count += 1 if self.running_batch is not None: if counter_count % 50 == 0: @@ -91,26 +86,28 @@ async def loop_for_fwd(self): self.stats_tool.print_stats() if self.running_batch is None: - await asyncio.sleep(0.01) # 10ms + time.sleep(10) # 10ms - async def _step(self): + def _step(self): """ handle the requests """ # 删除所有已经 finished 的 req + print("in step forward") if self.running_batch is None: new_batch = self.req_queue.generate_new_batch(self.running_batch) if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - await self._prefill_batch(self.running_batch) + print(new_batch.reqs) + self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 return if self.has_wait_tokens < self.max_wait_tokens: self.stats_tool.count_output_tokens(self.running_batch) - await self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return @@ -118,21 +115,22 @@ async def _step(self): new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) if new_mini_batch is not None: self.stats_tool.count_prompt_tokens(new_mini_batch) - await self._prefill_batch(new_mini_batch) + self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): - await self._merge_batch(self.running_batch, new_mini_batch) + self._merge_batch(self.running_batch, new_mini_batch) self.running_batch.merge(new_mini_batch) self.has_wait_tokens = 0 else: self.stats_tool.count_output_tokens(self.running_batch) - await self._decode_batch(self.running_batch) + self._decode_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens += 1 return - async def _init_batch(self, batch: Batch): + def _init_batch(self, batch: Batch, dtype="fp16"): reqs = [r.to_rpc_obj() for r in batch.reqs] + batch_id = batch.batch_id # rets = [self.model_rpcs[tp_rank].init_batch(batch.batch_id, reqs) for tp_rank in range(self.world_size)] if self.world_size != 1: batch_id, reqs, dtype = obtain(batch_id), obtain(reqs), obtain(dtype) @@ -148,19 +146,18 @@ async def _init_batch(self, batch: Batch): reqs, dtype, torch.cuda.current_device(), - self.engine.model.mem_manager, - self.engine.model.vocab_size, + self.engine.cache_manager, + self.engine.model.config.vocab_size, ) self.engine.cache[batch_id] = batch_data return - async def _prefill_batch(self, batch): - await self._init_batch(batch) + def _prefill_batch(self, batch): + self._init_batch(batch) # rets = [self.model_rpcs[tp_rank].foward(batch.batch_id) for tp_rank in range(self.world_size)] # TODO: figure out if cache and batch id 
is needed rets = self.engine._prefill_batch(batch.batch_id) - ans = await asyncio.gather(*rets) - + ans = rets if self.world_size != 1: req_to_out_token_id = obtain(ans[0]) else: @@ -168,13 +165,13 @@ async def _prefill_batch(self, batch): self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) # self._send_to_detokenization_proc(batch, req_to_out_token_id) - await self._handle_finish_req(batch, has_new_finished_req) + self._handle_finish_req(batch, has_new_finished_req) return - async def _decode_batch(self, batch: Batch): + def _decode_batch(self, batch: Batch): # rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)] rets = self.engine._decode_batch(batch.batch_id) - ans = await asyncio.gather(*rets) + ans = rets if self.world_size != 1: req_to_out_token_id = obtain(ans[0]) # gather or something else: @@ -182,10 +179,10 @@ async def _decode_batch(self, batch: Batch): self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) # self._send_to_detokenization_proc(batch, req_to_out_token_id) - await self._handle_finish_req(batch, has_new_finished_req) + self._handle_finish_req(batch, has_new_finished_req) return - async def _filter_batch(self, batch: Batch): + def _filter_batch(self, batch: Batch): batch_id = batch.batch_id req_id_list = [r.request_id for r in batch.reqs] if self.world_size != 1: @@ -194,11 +191,9 @@ async def _filter_batch(self, batch: Batch): filter_batch = batch.filter(req_id_list) del batch self.engine.cache[batch_id] = filter_batch - # rets = [self.model_rpcs[tp_rank].filter_batch(batch.batch_id, req_id_list) for tp_rank in range(self.world_size)] - # await asyncio.gather(*rets) return - async def _merge_batch(self, batch1, batch2): + def _merge_batch(self, batch1, batch2): batch1 = self.engine.cache.pop(batch1.batch_id) batch2 = self.engine.cache.pop(batch2.batch_id) # rets = [self.model_rpcs[tp_rank].merge_batch(batch1.batch_id, batch2.batch_id) for tp_rank in range(self.world_size)] @@ -210,7 +205,7 @@ async def _merge_batch(self, batch1, batch2): del batch2 return - async def _remove_batch(self, batch): + def _remove_batch(self, batch): batch = self.engine.cache.pop(batch.batch_id) batch.free_self() del batch @@ -218,13 +213,13 @@ async def _remove_batch(self, batch): # await asyncio.gather(*rets) return - async def _handle_finish_req(self, batch: Batch, has_new_finished_req): + def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: batch.filter_finished() if batch.is_clear(): - await self._remove_batch(batch) + self._remove_batch(batch) else: - await self._filter_batch(batch) + self._filter_batch(batch) return def _filter_runing_batch(self): @@ -244,7 +239,7 @@ def clean_up(self): pass -def start_router_process(args, tp_engine, router_port): +def start_router_process(args, tp_engine, waiting_req_list): try: batch_manager = DynamicBatchManager( tp_engine=tp_engine, @@ -253,52 +248,45 @@ def start_router_process(args, tp_engine, router_port): batch_max_tokens=args.batch_max_tokens, running_max_req_size=args.running_max_req_size, eos_id=args.eos_id, - router_port=router_port, log_stats=not args.disable_log_stats, log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, ) - except Exception as e: - import traceback - - err_str = "\n".join(traceback.format_exception(e)) - print(err_str) + except Exception: # may need use logger batch_manager.clean_up() 
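_handle_finish_req() above implements a simple policy: once any request in a batch has finished, the finished ones are dropped, and the batch is either freed entirely (if nothing is left) or filtered down to the survivors. A small stand-alone sketch of that decision with stand-in types (ToyReq and ToyBatch are illustrative, not the real io_struct classes):

from dataclasses import dataclass
from typing import List


@dataclass
class ToyReq:
    request_id: int
    finished: bool = False


class ToyBatch:
    def __init__(self, reqs: List[ToyReq]) -> None:
        self.reqs = reqs

    def filter_finished(self) -> None:
        self.reqs = [r for r in self.reqs if not r.finished]

    def is_clear(self) -> bool:
        return len(self.reqs) == 0


def handle_finish(batch: ToyBatch, has_new_finished_req: bool) -> str:
    # mirrors _handle_finish_req: keep the batch, remove it entirely, or filter it down
    if not has_new_finished_req:
        return "keep"
    batch.filter_finished()
    return "remove" if batch.is_clear() else "filter"


b = ToyBatch([ToyReq(0, finished=True), ToyReq(1, finished=False)])
assert handle_finish(b, has_new_finished_req=True) == "filter"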
raise print("start router process") - # batch_manager.loop_for_fwd() - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.create_task(batch_manager.loop_for_fwd()) - print("good") + batch_manager.loop_for_fwd() return if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tp", type=int, help="tp_size", default=1) - parser.add_argument("--max_total_token_num", type=int, default=2048, help="max_total_token_num") - parser.add_argument("-b", "--batch_max_tokens", type=int, default=1024, help="max tokens of one batch") + parser.add_argument("--max_total_token_num", type=int, default=42, help="max_total_token_num") + parser.add_argument("-b", "--batch_max_tokens", type=int, default=42, help="max tokens of one batch") parser.add_argument("--running_max_req_size", type=int, default=2, help="max request size of running batch ") parser.add_argument("--eos_id", type=int, default=0, help="The end token of a seq") parser.add_argument("--disable_log_stats", type=bool, default=False) parser.add_argument("--log_stats_interval", type=int, default=10) args = parser.parse_args() + sampling_params = SamplingParams() - colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=8081, backend="nccl") + req1 = Req(0, [10, 10, 10, 10, 10], sampling_params) + req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm") - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - orig_model = model_fn() - init_to_get_rotary(orig_model.model, base=10000) - orig_model = orig_model.half() - data = data_gen_fn() - - shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) - BATCH_SIZE = 8 - MAX_INPUT_LEN = 12 - MAX_OUTPUT_LEN = 100 - infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - start_router_process(args=args, tp_engine=infer_engine, router_port=12345) + colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=8081, backend="nccl") + tokenizer = LlamaTokenizer.from_pretrained("/data/scratch/llama-7b-hf") + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-7b-hf", pad_token_id=tokenizer.eos_token_id) + model = model.half() + init_to_get_rotary(model.model, base=10000) + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, 2, 5, 16) + start_router_process(args=args, tp_engine=infer_engine, waiting_req_list=waiting_list) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index c7c111aa4c5d..bdb3f135acd1 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -from dynamic_batching.infer_batch import InferBatch # may intergrate with batchinfer state from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig from transformers.generation.stopping_criteria import StoppingCriteriaList @@ -14,6 +13,8 @@ from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager +# from dynamic_batching.infer_batch import InferBatch + DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 _supported_models = [ @@ -276,7 +277,7 @@ def prepare_batch_state(self, 
inputs) -> BatchInferState: attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - + print(input_ids_list.shape) seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 @@ -310,6 +311,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state @torch.no_grad() @@ -340,6 +342,7 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch elif isinstance(model, BloomForCausalLM): model = self.model.transformer setattr(model, "infer_state", batch_infer_state) + print(model.infer_state) outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) @@ -373,19 +376,41 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) infer_state.seq_len += 1 + @torch.no_grad() def forward(self, batch_id, is_prefill): - batch: InferBatch = self.cache.pop(batch_id) - kwargs = { + batch = self.cache.pop(batch_id) + print(batch) + all_input_ids = torch.tensor(batch.all_input_ids).cuda() + print(all_input_ids) + infer_state = self.prepare_batch_state(all_input_ids) + + batch_args = { "batch_size": len(batch), - "total_token_num": batch.nopad_total_token_num, "max_len_in_batch": batch.nopad_max_len_in_batch, - "input_ids": batch.input_ids, - "b_loc": batch.nopad_b_loc, - "b_start_loc": batch.nopad_b_start_loc, - "b_seq_len": batch.nopad_b_seq_len, - "is_prefill": is_prefill, + "block_loc": batch.nopad_b_loc, + "start_loc": batch.nopad_b_start_loc, + "seq_len": batch.nopad_b_seq_len, + "cache_manager": batch.cache_manager, + "is_context_stage": is_prefill, } - logits = self.model.forward(**kwargs) + + BatchInferState(**batch_args) + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + setattr(model, "infer_state", infer_state) + print(model.infer_state) + + position_ids = torch.arange(0, 5, dtype=torch.long, device="cuda") + position_ids = position_ids.repeat(2, 1) + self.generate(input_tokens=all_input_ids) + logits = self.model.forward(input_ids=all_input_ids, position_ids=position_ids) + + print(logits) + + print("yeah") next_token_ids, next_token_probs = sample(logits, batch) next_token_ids = next_token_ids.detach().cpu().numpy() next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() @@ -417,9 +442,11 @@ def forward(self, batch_id, is_prefill): self.cache[batch.batch_id] = batch return output_dict + @torch.no_grad() def _prefill_batch(self, batch_id): return self.forward(batch_id, is_prefill=True) + @torch.no_grad() def _decode_batch(self, batch_id): return self.forward(batch_id, is_prefill=False) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 64d6e947e924..7e463d4f64dd 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -77,6 +77,7 @@ def llama_model_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else 
self.config.use_cache # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -107,7 +108,9 @@ def llama_model_forward( infer_state.init_block_loc( infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index ) + print(infer_state) else: + return infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) if alloc_mem is not None: @@ -132,10 +135,12 @@ def llama_model_forward( position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) + position_ids = position_ids.repeat(batch_size, 1) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - + print(position_ids.shape) + print(position_ids) if infer_state.is_context_stage: infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 @@ -143,6 +148,7 @@ def llama_model_forward( infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 ) + else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) @@ -169,7 +175,6 @@ def llama_model_forward( next_decoder_cache = () if use_cache else None infer_state.decode_layer_id = 0 - for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] if past_key_values is not None else None # NOTE: modify here for passing args to decoder layer @@ -220,7 +225,7 @@ def llama_decoder_layer_forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - + print(hidden_states.shape) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -281,12 +286,16 @@ def llama_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) - + print(self.num_heads, self.head_dim) + # 改 rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + print(key_buffer.shape) + print("shape of mem key buffer", mem_manager.key_buffer[layer_id].shape) copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) return @@ -296,7 +305,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, if infer_state.is_context_stage: # first token generation - + print(infer_state.decode_layer_id) # copy key and value calculated in current step to memory manager _copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -305,6 +314,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, infer_state.context_mem_index, infer_state.cache_manager, ) + print("yeah") attn_output = torch.empty_like(query_states) @@ -317,6 +327,8 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, infer_state.seq_len, infer_state.cache_manager.past_key_values_length, ) + 
print("yeah") + else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 02edcc9a903a..0520bc111384 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -51,7 +51,6 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" num_warps = 2 - _fwd_copy_kv_cache_dest[(seq_len,)]( k_ptr, dest_index_ptr, diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index bc229b17e08c..cd23a36cd3d3 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -27,8 +27,10 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- - input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + input_ids = torch.Tensor( + [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + ).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm @@ -47,10 +49,10 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output.logits.mean() config = LlamaConfig( - num_hidden_layers=4, - hidden_size=128, + num_hidden_layers=8, + hidden_size=4096, intermediate_size=256, - num_attention_heads=4, + num_attention_heads=32, max_position_embeddings=128, num_labels=16, ) diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 699ba7b52fe0..551f36770327 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -42,7 +42,7 @@ def run_chatglm2_test(test_config): enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True ) infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - + print(input_ids) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) outputs = infer_engine.generate(input_ids, **generate_kwargs) assert outputs is not None diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 0e5efe68508a..fc329ce5c958 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -3,18 +3,18 @@ import pytest import torch from packaging import version +from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -TPSIZE = 2 -BATCH_SIZE = 8 -MAX_INPUT_LEN = 12 +TPSIZE = 1 +BATCH_SIZE = 2 +MAX_INPUT_LEN = 8 MAX_OUTPUT_LEN = 100 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") @@ -53,22 +53,27 @@ def init_to_get_rotary(self, base=10000): ], ) def run_llama_test(test_config): - sub_model_zoo = 
model_zoo.get_sub_registry("transformers_llama_for_casual_lm") - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - orig_model = model_fn() - init_to_get_rotary(orig_model.model, base=10000) - orig_model = orig_model.half() - data = data_gen_fn() - - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(data, **generate_kwargs) - - assert outputs is not None + tokenizer = LlamaTokenizer.from_pretrained("/data/scratch/llama-7b-hf") + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-7b-hf", pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(model.model, base=10000) + model = model.half() + + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + input_tokens = { + "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + torch.cuda.synchronize() + + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + + assert outputs is not None def check_llama(rank, world_size, port): From e317b79afc8718b19a5e7b4f35fff0162d0e67c6 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 28 Sep 2023 19:37:01 +0800 Subject: [PATCH 04/12] fix --- .../inference/dynamic_batching/io_struct.py | 61 +++++++++++-------- colossalai/inference/manager.py | 13 ++-- .../inference/tensor_parallel/engine.py | 39 ++++++------ .../tensor_parallel/modeling/llama.py | 18 +----- 4 files changed, 62 insertions(+), 69 deletions(-) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 5324ee262986..cd6fbd93ad42 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -1,6 +1,6 @@ +from typing import Dict, List, Tuple + from .sampling_params import SamplingParams -from typing import Dict, List, Optional, Tuple -import asyncio class Req: @@ -16,30 +16,35 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): self.aborted = False def to_rpc_obj(self): - return {"request_id": self.request_id, - "input_id": self.prompt_ids, - "output_len": self.max_output_len, - "sampling_param": self.sample_params.to_dict() } + return { + "request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict(), + } def to_req_detokenization_state(self): - out = ReqDetokenizationState(self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos) + out = ReqDetokenizationState( + self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos + ) if self.output_metadata_list: out.gen_metadata.update(self.output_metadata_list[-1]) return out - + def stop_sequences_matched(self): - for stop_token_ids in self.sample_params.stop_sequences: - stop_len = len(stop_token_ids) - if stop_len > 0: - if len(self.output_ids) >= stop_len: - if all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in 
range(stop_len)): - return True + # should we add stpp sequences to the sample params? + if self.sample_params.stop_sequences is not None: + for stop_token_ids in self.sample_params.stop_sequences: + stop_len = len(stop_token_ids) + if stop_len > 0: + if len(self.output_ids) >= stop_len: + if all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + return True return False def __repr__(self): - return (f"request_id(n={self.request_id}, " - f"prompt_ids={self.prompt_ids}, ") - + return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " + class ReqDetokenizationState: def __init__( @@ -60,6 +65,7 @@ def __init__( self.ignore_eos = ignore_eos self.gen_metadata = {} + class Batch: def __init__(self, batch_id, reqs: List[Req]): self.batch_id = batch_id @@ -77,7 +83,7 @@ def calcu_max_tokens(self): for req in self.reqs: tokens += req.input_len + req.max_output_len return tokens - + def calcu_used_tokens(self): tokens = 0 for req in self.reqs: @@ -116,18 +122,23 @@ def merge(self, mini_batch): return def __repr__(self): - return (f"batch_id={self.batch_id}, " - f"reqs={self.reqs}, ") - + return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + + class BatchTokenIdOut: def __init__(self): - self.reqs_infs: List[Tuple[str, int, Dict, bool, bool]] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + self.reqs_infs: List[ + Tuple[str, int, Dict, bool, bool] + ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + class BatchStrOut: def __init__(self): - self.reqs_infs: List[Tuple[str, str, Dict, bool, bool]] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] - + self.reqs_infs: List[ + Tuple[str, str, Dict, bool, bool] + ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + + class AbortReq: def __init__(self, req_id): self.req_id = req_id - diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 39d6fa269384..02ae2da611a2 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -93,13 +93,11 @@ def _step(self): handle the requests """ # 删除所有已经 finished 的 req - print("in step forward") if self.running_batch is None: new_batch = self.req_queue.generate_new_batch(self.running_batch) if new_batch is not None: self.stats_tool.count_prompt_tokens(new_batch) self.running_batch = new_batch - print(new_batch.reqs) self._prefill_batch(self.running_batch) self._filter_runing_batch() self.has_wait_tokens = 0 @@ -158,10 +156,11 @@ def _prefill_batch(self, batch): # TODO: figure out if cache and batch id is needed rets = self.engine._prefill_batch(batch.batch_id) ans = rets + # ans should be a dict if self.world_size != 1: - req_to_out_token_id = obtain(ans[0]) + req_to_out_token_id = obtain(ans) else: - req_to_out_token_id = ans[0] + req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) # self._send_to_detokenization_proc(batch, req_to_out_token_id) @@ -173,9 +172,9 @@ def _decode_batch(self, batch: Batch): rets = self.engine._decode_batch(batch.batch_id) ans = rets if self.world_size != 1: - req_to_out_token_id = obtain(ans[0]) # gather or something + req_to_out_token_id = obtain(ans) # gather or something else: - req_to_out_token_id = ans[0] + req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) # self._send_to_detokenization_proc(batch, 
req_to_out_token_id) @@ -281,7 +280,7 @@ def start_router_process(args, tp_engine, waiting_req_list): waiting_list.append(req1) waiting_list.append(req2) - colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=8081, backend="nccl") + colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=8082, backend="nccl") tokenizer = LlamaTokenizer.from_pretrained("/data/scratch/llama-7b-hf") tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-7b-hf", pad_token_id=tokenizer.eos_token_id) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index bdb3f135acd1..d30d05014dfa 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -277,7 +277,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState: attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - print(input_ids_list.shape) seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 @@ -342,7 +341,6 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch elif isinstance(model, BloomForCausalLM): model = self.model.transformer setattr(model, "infer_state", batch_infer_state) - print(model.infer_state) outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) @@ -379,10 +377,10 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: @torch.no_grad() def forward(self, batch_id, is_prefill): batch = self.cache.pop(batch_id) - print(batch) - all_input_ids = torch.tensor(batch.all_input_ids).cuda() - print(all_input_ids) - infer_state = self.prepare_batch_state(all_input_ids) + if is_prefill: + input_ = torch.tensor(batch.all_input_ids).cuda() + else: + input_ = batch.input_ids.reshape(len(batch), 1) batch_args = { "batch_size": len(batch), @@ -394,33 +392,34 @@ def forward(self, batch_id, is_prefill): "is_context_stage": is_prefill, } - BatchInferState(**batch_args) + infer_state = BatchInferState(**batch_args) model = self.model if isinstance(model, LlamaForCausalLM): model = self.model.model elif isinstance(model, BloomForCausalLM): model = self.model.transformer setattr(model, "infer_state", infer_state) - print(model.infer_state) - - position_ids = torch.arange(0, 5, dtype=torch.long, device="cuda") - position_ids = position_ids.repeat(2, 1) - self.generate(input_tokens=all_input_ids) - logits = self.model.forward(input_ids=all_input_ids, position_ids=position_ids) - print(logits) + output = self.model.forward(input_ids=input_) + logits = output.logits + prob_out = torch.softmax(logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) + prob_out = torch.log(prob_out).detach().cpu().numpy() + predict_ids = predict_ids.detach().cpu().numpy() + # [batch_size, 1 ,vocab_size] + # next_token_ids, next_token_probs = sample(logits, batch) + # next_token_ids = next_token_ids.detach().cpu().numpy() + # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - print("yeah") - next_token_ids, next_token_probs = sample(logits, batch) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() output_dict = {} new_input_ids = [] for i, (r, all_input_ids, next_token_id, next_token_logprob) in 
enumerate( - zip(batch.requests, batch.all_input_ids, next_token_ids, next_token_logprobs) + zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) ): + next_token_id = int(next_token_id[0]) + next_token_logprob = next_token_logprob[0][next_token_id] # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") - all_input_ids.append(int(next_token_id)) + all_input_ids.append(next_token_id) # all_input_ids_tensor = None new_input_ids.append(next_token_id) batch.all_input_ids[i] = all_input_ids diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 7e463d4f64dd..7a31802e00f6 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -108,9 +108,7 @@ def llama_model_forward( infer_state.init_block_loc( infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index ) - print(infer_state) else: - return infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) if alloc_mem is not None: @@ -139,8 +137,6 @@ def llama_model_forward( position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - print(position_ids.shape) - print(position_ids) if infer_state.is_context_stage: infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 @@ -225,7 +221,6 @@ def llama_decoder_layer_forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - print(hidden_states.shape) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -285,17 +280,12 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin - # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) - print(self.num_heads, self.head_dim) - # 改 + rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - print(key_buffer.shape) - print("shape of mem key buffer", mem_manager.key_buffer[layer_id].shape) copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) return @@ -305,7 +295,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, if infer_state.is_context_stage: # first token generation - print(infer_state.decode_layer_id) # copy key and value calculated in current step to memory manager _copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -314,10 +303,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, infer_state.context_mem_index, infer_state.cache_manager, ) - print("yeah") - attn_output = torch.empty_like(query_states) - llama_context_attn_fwd( query_states, key_states, @@ -327,8 +313,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, infer_state.seq_len, infer_state.cache_manager.past_key_values_length, ) - print("yeah") - else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to 
key cache and value cache in cache manager directly From ac38fae34b2ee5a7a98d85652f0db935ec1ccd50 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sun, 8 Oct 2023 18:12:34 +0800 Subject: [PATCH 05/12] fix dynamic batching --- .../inference/dynamic_batching/infer_batch.py | 21 ++- colossalai/inference/manager.py | 151 ++++++------------ .../inference/tensor_parallel/engine.py | 26 +-- .../tensor_parallel/kvcache_manager.py | 2 + .../tensor_parallel/modeling/llama.py | 15 +- tests/kit/model_zoo/torchrec/__init__.py | 2 +- .../test_dynamic_batching/test_forward.py | 72 +++++++++ 7 files changed, 160 insertions(+), 129 deletions(-) create mode 100644 tests/test_infer/test_dynamic_batching/test_forward.py diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index 4ba3024d0d02..244f8fd560d6 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -4,7 +4,6 @@ import numpy as np import torch -from lightllm.common.configs.config import setting from colossalai.inference.tensor_parallel import MemoryManager @@ -53,11 +52,19 @@ class InferBatch: nopad_b_start_loc: torch.Tensor nopad_b_seq_len: torch.Tensor cache_manager: MemoryManager + max_total_len: int @classmethod @torch.no_grad() def init_batch( - cls, batch_id, requests, dtype: torch.dtype, device: torch.device, cache_manager: MemoryManager, vocab_size: int + cls, + batch_id, + requests, + dtype: torch.dtype, + device: torch.device, + cache_manager: MemoryManager, + vocab_size: int, + max_total_len: int, ): input_lengths = [] all_input_ids = [] @@ -68,7 +75,7 @@ def init_batch( nopad_total_token_num = 0 nopad_max_len_in_batch = 0 - nopad_b_loc = torch.empty((len(requests), setting["max_req_total_len"] + 12), dtype=torch.long, device="cuda") + nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") for i, r in enumerate(requests): # request id -> idx in list mapping @@ -91,6 +98,7 @@ def init_batch( nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + if len(requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) else: @@ -114,6 +122,7 @@ def init_batch( out_token_id_counts=out_token_id_counts, sampling_param_list=sampling_param_list, cache_manager=cache_manager, + max_total_len=max_total_len, ) @torch.no_grad() @@ -145,9 +154,7 @@ def filter(self, request_ids: List[int]): nopad_total_token_num = 0 nopad_max_len_in_batch = 0 - nopad_b_loc = torch.empty( - (len(request_ids), setting["max_req_total_len"] + 12), dtype=torch.long, device="cuda" - ) + nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") @@ -228,7 +235,7 @@ def merge(cls, batch1, batch2): nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) - nopad_b_loc = torch.empty((new_batch_size, setting["max_req_total_len"] + 12), dtype=torch.long, device="cuda") + nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") 
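init_batch() above packs all prompts into one flattened, padding-free buffer, so each request's start offset is simply the running sum of the preceding sequence lengths. A worked example of that bookkeeping on CPU (the original allocates these tensors on CUDA; the lengths here are made up):

import torch

input_lengths = [5, 3, 4]                                          # three requests
nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32)
nopad_b_start_loc = torch.zeros(len(input_lengths), dtype=torch.int32)
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]

print(nopad_b_start_loc.tolist())   # [0, 5, 8]  -> where each request starts in the flat buffer
print(int(nopad_b_seq_len.sum()))   # 12         -> nopad_total_token_num for this batch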
nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") nopad_start_loc_len_temp = 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 02ae2da611a2..4153e18aa8af 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,59 +1,54 @@ -import argparse import time from typing import List -from dynamic_batching.infer_batch import InferBatch -from dynamic_batching.io_struct import Batch, Req -from dynamic_batching.req_queue import ReqQueue -from dynamic_batching.sampling_params import SamplingParams -from dynamic_batching.stas import Stats -from rpyc.utils.classic import obtain -from tensor_parallel.engine import TPInferEngine -from transformers import LlamaForCausalLM, LlamaTokenizer +from .dynamic_batching.infer_batch import InferBatch +from .dynamic_batching.io_struct import Batch, Req +from .dynamic_batching.req_queue import ReqQueue +from .dynamic_batching.sampling_params import SamplingParams +from .dynamic_batching.stas import Stats +from .tensor_parallel import TPInferEngine -import colossalai -from colossalai.shardformer import ShardConfig -from tests.test_infer.test_llama_infer import init_to_get_rotary - -# faulthandler.enable() class DynamicBatchManager: def __init__( self, tp_engine: TPInferEngine, - world_size, max_total_token_num, batch_max_tokens, - running_max_req_size, eos_id, log_stats=True, log_stats_interval=10, running_batch: Batch = None, waiting_req_list=[], ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ self.engine = tp_engine - self.world_size = world_size self.max_total_token_num = max_total_token_num - + running_max_req_size = self.engine.max_batch_size self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) - # all the inputs should be put into req_queue + # all the inputs should be put into req_queue: waiting req list self.running_batch: Batch = running_batch self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - # context = zmq.asyncio.Context(2) - # self.send_to_detokenization = context.socket(zmq.PUSH) - # self.send_to_detokenization.connect(f"tcp://127.0.0.1:{detokenization_port}") - self.stats_tool = Stats(log_stats, log_stats_interval) - # In Torch serve, model is initialized before manage - def wait_to_model_ready(self): - pass - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + """ + Add new request to req queue, during initialization all requests are held in waiting list. + """ req = Req(request_id, prompt_ids, sampling_params) self.req_queue.append(req) return @@ -71,6 +66,9 @@ def abort(self, request_id): return def loop_for_fwd(self): + """ + The main loop for a dynamic batching process. 
+ """ counter_count = 0 while self.running_batch is not None or self.req_queue.waiting_req_list: self._step() @@ -86,13 +84,13 @@ def loop_for_fwd(self): self.stats_tool.print_stats() if self.running_batch is None: - time.sleep(10) # 10ms + time.sleep(0.1) # 10ms def _step(self): """ - handle the requests + Logic for handling requests """ - # 删除所有已经 finished 的 req + if self.running_batch is None: new_batch = self.req_queue.generate_new_batch(self.running_batch) if new_batch is not None: @@ -129,16 +127,14 @@ def _step(self): def _init_batch(self, batch: Batch, dtype="fp16"): reqs = [r.to_rpc_obj() for r in batch.reqs] batch_id = batch.batch_id - # rets = [self.model_rpcs[tp_rank].init_batch(batch.batch_id, reqs) for tp_rank in range(self.world_size)] - if self.world_size != 1: - batch_id, reqs, dtype = obtain(batch_id), obtain(reqs), obtain(dtype) + import torch if dtype == "fp16": dtype = torch.float16 else: assert False, "error dtype" - # cache may be removed + batch_data = InferBatch.init_batch( batch_id, reqs, @@ -146,71 +142,60 @@ def _init_batch(self, batch: Batch, dtype="fp16"): torch.cuda.current_device(), self.engine.cache_manager, self.engine.model.config.vocab_size, + self.engine.max_input_len + self.engine.max_output_len, ) self.engine.cache[batch_id] = batch_data - return def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ self._init_batch(batch) - # rets = [self.model_rpcs[tp_rank].foward(batch.batch_id) for tp_rank in range(self.world_size)] # TODO: figure out if cache and batch id is needed - rets = self.engine._prefill_batch(batch.batch_id) - ans = rets - # ans should be a dict - if self.world_size != 1: - req_to_out_token_id = obtain(ans) - else: - req_to_out_token_id = ans + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - # self._send_to_detokenization_proc(batch, req_to_out_token_id) self._handle_finish_req(batch, has_new_finished_req) - return + # delete finished reqs def _decode_batch(self, batch: Batch): - # rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)] - rets = self.engine._decode_batch(batch.batch_id) - ans = rets - if self.world_size != 1: - req_to_out_token_id = obtain(ans) # gather or something - else: - req_to_out_token_id = ans + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans self._add_token_id_to_req(batch, req_to_out_token_id) has_new_finished_req = batch.mark_finished_req(self.eos_id) - # self._send_to_detokenization_proc(batch, req_to_out_token_id) self._handle_finish_req(batch, has_new_finished_req) - return def _filter_batch(self, batch: Batch): batch_id = batch.batch_id req_id_list = [r.request_id for r in batch.reqs] - if self.world_size != 1: - batch_id, req_id_list = obtain(batch_id), obtain(req_id_list) batch = self.engine.cache.pop(batch_id) filter_batch = batch.filter(req_id_list) del batch self.engine.cache[batch_id] = filter_batch - return def _merge_batch(self, batch1, batch2): + """ + Merge new mini batch into running batch. 
+ """ batch1 = self.engine.cache.pop(batch1.batch_id) batch2 = self.engine.cache.pop(batch2.batch_id) - # rets = [self.model_rpcs[tp_rank].merge_batch(batch1.batch_id, batch2.batch_id) for tp_rank in range(self.world_size)] - # await asyncio.gather(*rets) m_batch = InferBatch.merge(batch1, batch2) self.engine.cache[batch1.batch_id] = m_batch del batch1 del batch2 - return def _remove_batch(self, batch): + """ + Remove finished batch. + """ batch = self.engine.cache.pop(batch.batch_id) batch.free_self() del batch - # rets = [self.model_rpcs[tp_rank].remove_batch(batch.batch_id) for tp_rank in range(self.world_size)] - # await asyncio.gather(*rets) - return def _handle_finish_req(self, batch: Batch, has_new_finished_req): if has_new_finished_req: @@ -219,12 +204,10 @@ def _handle_finish_req(self, batch: Batch, has_new_finished_req): self._remove_batch(batch) else: self._filter_batch(batch) - return def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): self.running_batch = None - return def _add_token_id_to_req(self, batch: Batch, req_ans): for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): @@ -234,18 +217,16 @@ def _add_token_id_to_req(self, batch: Batch, req_ans): return def clean_up(self): - # this logic should be implemented + # this logic should be implemented in the future. pass -def start_router_process(args, tp_engine, waiting_req_list): +def start_dynamic_batching(args, tp_engine, waiting_req_list): try: batch_manager = DynamicBatchManager( tp_engine=tp_engine, - world_size=args.tp, max_total_token_num=args.max_total_token_num, batch_max_tokens=args.batch_max_tokens, - running_max_req_size=args.running_max_req_size, eos_id=args.eos_id, log_stats=not args.disable_log_stats, log_stats_interval=args.log_stats_interval, @@ -253,39 +234,9 @@ def start_router_process(args, tp_engine, waiting_req_list): ) except Exception: - # may need use logger batch_manager.clean_up() raise - print("start router process") + print("start dynamic batching process") batch_manager.loop_for_fwd() return - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--tp", type=int, help="tp_size", default=1) - parser.add_argument("--max_total_token_num", type=int, default=42, help="max_total_token_num") - parser.add_argument("-b", "--batch_max_tokens", type=int, default=42, help="max tokens of one batch") - parser.add_argument("--running_max_req_size", type=int, default=2, help="max request size of running batch ") - parser.add_argument("--eos_id", type=int, default=0, help="The end token of a seq") - parser.add_argument("--disable_log_stats", type=bool, default=False) - parser.add_argument("--log_stats_interval", type=int, default=10) - args = parser.parse_args() - sampling_params = SamplingParams() - - req1 = Req(0, [10, 10, 10, 10, 10], sampling_params) - req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) - waiting_list = [] - waiting_list.append(req1) - waiting_list.append(req2) - - colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=8082, backend="nccl") - tokenizer = LlamaTokenizer.from_pretrained("/data/scratch/llama-7b-hf") - tokenizer.pad_token_id = tokenizer.unk_token_id - model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-7b-hf", pad_token_id=tokenizer.eos_token_id) - model = model.half() - init_to_get_rotary(model.model, base=10000) - shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) - infer_engine = TPInferEngine(model, shard_config, 2, 5, 16) 
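# A minimal usage sketch of the new start_dynamic_batching entry point, assuming a CUDA device
# and an already-initialized colossalai distributed context (the tests later in this series call
# colossalai.launch first). The tiny LlamaConfig and the numeric limits are illustrative only;
# `args` can be any object exposing max_total_token_num, batch_max_tokens, eos_id,
# disable_log_stats and log_stats_interval.
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig


def run_dynamic_batching_demo(args):
    # a couple of dummy requests; the prompt token ids are placeholders
    sampling_params = SamplingParams()
    waiting_list = [Req(i, [10, 10, 10, 10, 10], sampling_params) for i in range(2)]

    # a deliberately small llama model so the sketch stays cheap
    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config).half()

    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    engine = TPInferEngine(model, shard_config, 2, 5, 16)  # max_batch_size, max_input_len, max_output_len

    # the manager loops until every waiting request has finished generating
    start_dynamic_batching(args, tp_engine=engine, waiting_req_list=waiting_list)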
- start_router_process(args=args, tp_engine=infer_engine, waiting_req_list=waiting_list) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index d30d05014dfa..a55ddabe9098 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -376,6 +376,9 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: @torch.no_grad() def forward(self, batch_id, is_prefill): + """ + Forward is used in Dynamic Batching + """ batch = self.cache.pop(batch_id) if is_prefill: input_ = torch.tensor(batch.all_input_ids).cuda() @@ -402,22 +405,27 @@ def forward(self, batch_id, is_prefill): output = self.model.forward(input_ids=input_) logits = output.logits - prob_out = torch.softmax(logits, dim=-1) + # bsz, seq_len, vocab_size + prob_out = torch.softmax( + logits[ + :, + -1, + ], + dim=-1, + ).squeeze(1) + # prob_out: bsz, vocab_size predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) prob_out = torch.log(prob_out).detach().cpu().numpy() predict_ids = predict_ids.detach().cpu().numpy() - # [batch_size, 1 ,vocab_size] - # next_token_ids, next_token_probs = sample(logits, batch) - # next_token_ids = next_token_ids.detach().cpu().numpy() - # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() + # [ batch_size, 1 ] output_dict = {} new_input_ids = [] for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) ): - next_token_id = int(next_token_id[0]) - next_token_logprob = next_token_logprob[0][next_token_id] + next_token_id = int(next_token_id) + next_token_logprob = next_token_logprob[next_token_id] # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") all_input_ids.append(next_token_id) # all_input_ids_tensor = None @@ -432,12 +440,8 @@ def forward(self, batch_id, is_prefill): output_dict[r["request_id"]] = (int(next_token_id), metadata) batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() - batch.nopad_b_start_loc = batch.nopad_b_start_loc + torch.arange( - 0, len(batch), dtype=torch.int32, device="cuda" - ) batch.nopad_total_token_num += len(batch) batch.nopad_max_len_in_batch += 1 - batch.nopad_b_seq_len += 1 self.cache[batch.batch_id] = batch return output_dict diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index e74a3a491a7b..dd6f9d7c8fc6 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -94,6 +94,8 @@ def free(self, free_index): """free memory by updating memory states based on given indexes""" self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 + # FIXME: should this be zero? 
+ self.past_key_values_length = 0 @torch.no_grad() def free_all(self): diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 7a31802e00f6..15955aee369a 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -71,8 +71,6 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - batch_size = input_ids.shape[0] # input_ids.shape[0] - infer_state = self.infer_state return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -89,14 +87,11 @@ def llama_model_forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + seq_length_with_past = seq_length_with_past + past_key_values_length # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -194,7 +189,7 @@ def llama_model_forward( # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 if not return_dict: diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 458d7bc81ce8..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -# from .torchrec import * +from .torchrec import * diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py new file mode 100644 index 000000000000..57cdd831c685 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -0,0 +1,72 @@ +import argparse + +import pytest +import torch +from packaging import version +from transformers import LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.test_infer.test_llama_infer import init_to_get_rotary + +TP_SIZE = 2 +MAX_BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 +MAX_OUTPUT_LEN = 16 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +def run(): + parser = argparse.ArgumentParser() + parser.add_argument("--max_total_token_num", type=int, default=42, help="max_total_token_num") + parser.add_argument("-b", "--batch_max_tokens", type=int, default=42, help="max tokens of one batch") + 
parser.add_argument("--eos_id", type=int, default=0, help="The end token of a seq") + parser.add_argument("--disable_log_stats", type=bool, default=False) + parser.add_argument("--log_stats_interval", type=int, default=10) + args = parser.parse_args() + sampling_params = SamplingParams() + + req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) + req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) + req3 = Req(2, [10, 10, 10, 10, 10], sampling_params) + req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) + + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + waiting_list.append(req4) + + tokenizer = LlamaTokenizer.from_pretrained("/data/scratch/llama-7b-hf") + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-7b-hf", pad_token_id=tokenizer.eos_token_id) + model = model.half() + + init_to_get_rotary(model.model, base=10000) + shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) + + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + start_dynamic_batching(args=args, tp_engine=infer_engine, waiting_req_list=waiting_list) + print("done") + + +def check_dynamic_forward(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching(): + spawn(check_dynamic_forward, TP_SIZE) + + +if __name__ == "__main__": + test_dynamic_batching() From aacb7b5337bc1e451c526d14615250ab9a064d14 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sun, 8 Oct 2023 18:27:24 +0800 Subject: [PATCH 06/12] llama infer --- tests/test_infer/test_llama_infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 93ee72ce8e0c..b424525a3719 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -13,9 +13,9 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -TPSIZE = 1 -BATCH_SIZE = 2 -MAX_INPUT_LEN = 8 +TPSIZE = 2 +BATCH_SIZE = 8 +MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") From cfd9bcf011a1ddc9fcbfe414f852f587441544f4 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 9 Oct 2023 15:17:33 +0800 Subject: [PATCH 07/12] finish test --- .../inference/dynamic_batching/infer_batch.py | 4 +- .../inference/dynamic_batching/io_struct.py | 7 ++ .../inference/dynamic_batching/req_queue.py | 5 +- colossalai/inference/manager.py | 4 +- .../tensor_parallel/modeling/_utils.py | 2 +- colossalai/kernel/triton/__init__.py | 2 + .../test_dynamic_batching_manager.py | 94 +++++++++++++++++++ .../test_dynamic_batching/test_forward.py | 11 +-- 8 files changed, 117 insertions(+), 12 deletions(-) create mode 100644 tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index 244f8fd560d6..c60cb92616d1 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ 
b/colossalai/inference/dynamic_batching/infer_batch.py @@ -216,6 +216,7 @@ def filter(self, request_ids: List[int]): out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], sampling_param_list=[self.sampling_param_list[_i] for _i in indices], cache_manager=self.cache_manager, + max_total_len=self.max_total_len, ) @classmethod @@ -234,7 +235,7 @@ def merge(cls, batch1, batch2): cumulative_batch_size = 0 nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) - + max_total_len = max(batch1.max_total_len, batch2.max_total_len) nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") @@ -283,6 +284,7 @@ def merge(cls, batch1, batch2): out_token_id_counts=out_token_id_counts, sampling_param_list=sampling_param_list, cache_manager=batches[0].cache_manager, + max_total_len=max_total_len, ) def __len__(self): diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index cd6fbd93ad42..5b66744f338d 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -105,6 +105,10 @@ def mark_finished_req(self, eos_id): return has_new_finish def filter_finished(self): + """ + Filter finished requests from the batch, the finished ones will be removed from 'reqs'. + """ + # TODO: the logic of return should be defined here. unfinished_req = [] for req in self.reqs: if not req.has_generate_finished: @@ -124,6 +128,9 @@ def merge(self, mini_batch): def __repr__(self): return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + def __len__(self): + return len(self.reqs) + class BatchTokenIdOut: def __init__(self): diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py index 61573486d0d1..e7691ea4d3ed 100644 --- a/colossalai/inference/dynamic_batching/req_queue.py +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -45,7 +45,7 @@ def _can_add_new_req(self, req): else: return False - def generate_new_batch(self, current_batch: Batch): + def generate_new_batch(self, current_batch: Batch = None): if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: return None self._init_cache_list(current_batch) @@ -69,3 +69,6 @@ def generate_new_batch(self, current_batch: Batch): return new_batch else: return None + + def __len__(self): + return self.waiting_req_list.__len__() diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 4153e18aa8af..09849b05018e 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -19,7 +19,7 @@ def __init__( log_stats=True, log_stats_interval=10, running_batch: Batch = None, - waiting_req_list=[], + waiting_req_list: List = [], ): """ Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager @@ -34,7 +34,7 @@ def __init__( """ self.engine = tp_engine self.max_total_token_num = max_total_token_num - running_max_req_size = self.engine.max_batch_size + running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) # 
all the inputs should be put into req_queue: waiting req list diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py index e476c3132538..068b64b4f829 100644 --- a/colossalai/inference/tensor_parallel/modeling/_utils.py +++ b/colossalai/inference/tensor_parallel/modeling/_utils.py @@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False): base = float(base) # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None)) + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) if ntk_alpha is not None: ntk_alpha = float(ntk_alpha) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 6dafc16dd4ed..070ebe45f659 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -1,5 +1,7 @@ try: import triton + + HAS_TRITON = True except ImportError: HAS_TRITON = False print("Triton is not installed. Please install Triton to use Triton kernels.") diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py new file mode 100644 index 000000000000..124f1f478b00 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -0,0 +1,94 @@ +import pytest +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import DynamicBatchManager +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 +MAX_OUTPUT_LEN = 16 + + +def run(): + sampling_params = SamplingParams() + + req1 = Req(0, [1], sampling_params) + req2 = Req(1, [2], sampling_params) + req3 = Req(2, [3], sampling_params) + # req 1-3 are initiliazed as token forward requests + req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + + # init model and tp engine + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + dynamic_batch_manager = DynamicBatchManager( + tp_engine=infer_engine, + max_total_token_num=42, + batch_max_tokens=42, + eos_id=0, + log_stats=False, + log_stats_interval=10, + waiting_req_list=waiting_list, + ) + before_add = len(dynamic_batch_manager.req_queue) + + # test add req function + dynamic_batch_manager.add_req(req4.prompt_ids, req4.sample_params, req4.request_id) + assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1 + + # test abort function + dynamic_batch_manager.abort(req4.request_id) + assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True + + # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode 
batch are tested + batch = dynamic_batch_manager.req_queue.generate_new_batch() + assert len(batch) == 2 + + dynamic_batch_manager._init_batch(batch) + assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None + + batch.reqs[0].has_generate_finished = True + # filter one finished + batch.filter_finished() + dynamic_batch_manager._filter_batch(batch) + assert len(dynamic_batch_manager.engine.cache) == 1 + + # test merge batch + new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch) + assert len(new_batch) == 1 + dynamic_batch_manager._init_batch(new_batch) + dynamic_batch_manager._merge_batch(batch, new_batch) + + assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2 + + +def check_dynamic_batching_manager(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching_manager(): + spawn(check_dynamic_batching_manager, 1) + + +if __name__ == "__main__": + test_dynamic_batching_manager() diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 57cdd831c685..4e48aabea2bc 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -3,7 +3,8 @@ import pytest import torch from packaging import version -from transformers import LlamaForCausalLM, LlamaTokenizer +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig import colossalai from colossalai.inference.dynamic_batching.io_struct import Req @@ -12,7 +13,6 @@ from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -from tests.test_infer.test_llama_infer import init_to_get_rotary TP_SIZE = 2 MAX_BATCH_SIZE = 2 @@ -42,17 +42,14 @@ def run(): waiting_list.append(req3) waiting_list.append(req4) - tokenizer = LlamaTokenizer.from_pretrained("/data/scratch/llama-7b-hf") - tokenizer.pad_token_id = tokenizer.unk_token_id - model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-7b-hf", pad_token_id=tokenizer.eos_token_id) + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) model = model.half() - init_to_get_rotary(model.model, base=10000) shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) start_dynamic_batching(args=args, tp_engine=infer_engine, waiting_req_list=waiting_list) - print("done") def check_dynamic_forward(rank, world_size, port): From b358232c015aa07dd1968a5ce6fa625d4a3bb164 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 9 Oct 2023 22:45:19 +0800 Subject: [PATCH 08/12] support different lengths generating --- .../inference/dynamic_batching/infer_batch.py | 2 -- colossalai/inference/manager.py | 2 ++ .../inference/tensor_parallel/engine.py | 10 ++++++- .../tensor_parallel/kvcache_manager.py | 2 +- .../tensor_parallel/modeling/llama.py | 30 ++++++++++--------- .../test_dynamic_batching/test_forward.py | 6 ++-- 6 files changed, 31 insertions(+), 21 deletions(-) diff --git 
a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index c60cb92616d1..7ecd10544a16 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -151,7 +151,6 @@ def filter(self, request_ids: List[int]): requests = [] all_input_ids = [] input_lengths = [] - nopad_total_token_num = 0 nopad_max_len_in_batch = 0 nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") @@ -175,7 +174,6 @@ def filter(self, request_ids: List[int]): ] ) remove_index = torch.cat(remove_index, dim=-1) - self.cache_manager.free(remove_index) nopad_max_len_in_batch = 0 diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 09849b05018e..7894380a7af8 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -151,6 +151,7 @@ def _prefill_batch(self, batch): For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. """ self._init_batch(batch) + # TODO: figure out if cache and batch id is needed ans = self.engine._prefill_batch(batch.batch_id) req_to_out_token_id = ans @@ -186,6 +187,7 @@ def _merge_batch(self, batch1, batch2): m_batch = InferBatch.merge(batch1, batch2) self.engine.cache[batch1.batch_id] = m_batch + print("merged_batch", m_batch) del batch1 del batch2 diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index a55ddabe9098..061f0a58ee91 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -377,9 +377,10 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: @torch.no_grad() def forward(self, batch_id, is_prefill): """ - Forward is used in Dynamic Batching + Forward is used in Dynamic Batching Manager """ batch = self.cache.pop(batch_id) + if is_prefill: input_ = torch.tensor(batch.all_input_ids).cuda() else: @@ -426,6 +427,8 @@ def forward(self, batch_id, is_prefill): ): next_token_id = int(next_token_id) next_token_logprob = next_token_logprob[next_token_id] + if r["request_id"] == 0 and is_prefill == False: + next_token_id = 0 # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") all_input_ids.append(next_token_id) # all_input_ids_tensor = None @@ -443,6 +446,11 @@ def forward(self, batch_id, is_prefill): batch.nopad_total_token_num += len(batch) batch.nopad_max_len_in_batch += 1 self.cache[batch.batch_id] = batch + + if len(batch) == 1: + print("filterd batch", batch) + print(" ") + return output_dict @torch.no_grad() diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index dd6f9d7c8fc6..4e414c71f4a3 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -95,7 +95,7 @@ def free(self, free_index): self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 # FIXME: should this be zero? 
- self.past_key_values_length = 0 + # self.past_key_values_length = 0 @torch.no_grad() def free_all(self): diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 14f14e0727bb..d777a964cee8 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -77,12 +77,12 @@ def llama_model_forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length - # NOT READY FOR PRIME TIME # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - seq_length_with_past = seq_length_with_past + past_key_values_length + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -102,18 +102,21 @@ def llama_model_forward( infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + + if batch_size == 2: + print("in decoding", infer_state.block_loc) + print(infer_state.max_len_in_batch) + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( @@ -123,6 +126,7 @@ def llama_model_forward( position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() + if infer_state.is_context_stage: infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 @@ -142,7 +146,7 @@ def llama_model_forward( # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = self._prepare_decoder_attention_mask( @@ -262,8 +266,6 @@ def llama_flash_attn_kvcache_forward( # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin @@ -292,7 +294,7 
@@ def llama_flash_attn_kvcache_forward( attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: if infer_state.decode_is_contiguous: @@ -329,7 +331,7 @@ def llama_flash_attn_kvcache_forward( infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 4e48aabea2bc..1faf851e829b 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -14,7 +14,7 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -TP_SIZE = 2 +TP_SIZE = 1 MAX_BATCH_SIZE = 2 MAX_INPUT_LEN = 5 MAX_OUTPUT_LEN = 16 @@ -33,8 +33,8 @@ def run(): req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) - req3 = Req(2, [10, 10, 10, 10, 10], sampling_params) - req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) + req3 = Req(2, [0, 0, 10, 10, 10], sampling_params) + req4 = Req(3, [0, 0, 10, 10, 10], sampling_params) waiting_list = [] waiting_list.append(req1) From f395cd75ce43d254a4670cac564b7ef155cba8b0 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 10 Oct 2023 11:19:07 +0800 Subject: [PATCH 09/12] del prints --- colossalai/inference/tensor_parallel/engine.py | 7 ------- colossalai/inference/tensor_parallel/kvcache_manager.py | 2 -- colossalai/inference/tensor_parallel/modeling/llama.py | 4 ---- 3 files changed, 13 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 061f0a58ee91..f7fb7a825694 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -427,8 +427,6 @@ def forward(self, batch_id, is_prefill): ): next_token_id = int(next_token_id) next_token_logprob = next_token_logprob[next_token_id] - if r["request_id"] == 0 and is_prefill == False: - next_token_id = 0 # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") all_input_ids.append(next_token_id) # all_input_ids_tensor = None @@ -446,11 +444,6 @@ def forward(self, batch_id, is_prefill): batch.nopad_total_token_num += len(batch) batch.nopad_max_len_in_batch += 1 self.cache[batch.batch_id] = batch - - if len(batch) == 1: - print("filterd batch", batch) - print(" ") - return output_dict @torch.no_grad() diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index 4e414c71f4a3..e74a3a491a7b 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -94,8 +94,6 @@ def free(self, free_index): """free memory by updating memory states based on given indexes""" self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 - # FIXME: should this be zero? 
- # self.past_key_values_length = 0 @torch.no_grad() def free_all(self): diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index d777a964cee8..958868a0974e 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -113,10 +113,6 @@ def llama_model_forward( # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - if batch_size == 2: - print("in decoding", infer_state.block_loc) - print(infer_state.max_len_in_batch) - if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( From 0f205c1f3520aec008339bf0b57940318b29a6ff Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 10 Oct 2023 11:21:04 +0800 Subject: [PATCH 10/12] del prints --- colossalai/inference/manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index 7894380a7af8..dd7336127f70 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -187,7 +187,6 @@ def _merge_batch(self, batch1, batch2): m_batch = InferBatch.merge(batch1, batch2) self.engine.cache[batch1.batch_id] = m_batch - print("merged_batch", m_batch) del batch1 del batch2 @@ -239,6 +238,5 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list): batch_manager.clean_up() raise - print("start dynamic batching process") batch_manager.loop_for_fwd() return From d288f104742adb6403691419f8fc51498e53285f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 10 Oct 2023 18:26:55 +0800 Subject: [PATCH 11/12] fix --- .../inference/dynamic_batching/infer_batch.py | 24 +++++++++++++------ .../inference/dynamic_batching/io_struct.py | 6 ++--- .../inference/dynamic_batching/req_queue.py | 7 ++---- .../dynamic_batching/sampling_params.py | 2 +- .../dynamic_batching/{stas.py => stats.py} | 0 colossalai/inference/manager.py | 7 +++--- tests/kit/model_zoo/transformers/llama.py | 6 ++--- 7 files changed, 29 insertions(+), 23 deletions(-) rename colossalai/inference/dynamic_batching/{stas.py => stats.py} (100%) diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index 7ecd10544a16..826272db3e11 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -1,6 +1,6 @@ import collections from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List , Tuple import numpy as np import torch @@ -65,7 +65,7 @@ def init_batch( cache_manager: MemoryManager, vocab_size: int, max_total_len: int, - ): + ) -> 'InferBatch': input_lengths = [] all_input_ids = [] requests_idx_mapping = {} @@ -76,6 +76,7 @@ def init_batch( nopad_total_token_num = 0 nopad_max_len_in_batch = 0 nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") + # to avoid memory leak , we pre-allocate 12 more space for each batch. 
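# A small CPU-only illustration of the per-request location buffer allocated above; the sizes are
# made up and the extra 12 columns simply mirror the headroom reserved in init_batch. Roughly,
# row i later holds the cache-manager indices used by request i, which is what free_self() and
# filter() slice when they release memory.
import torch

num_requests, max_total_len = 4, 32
nopad_b_loc = torch.empty((num_requests, max_total_len + 12), dtype=torch.long)
print(nopad_b_loc.shape)  # torch.Size([4, 44])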
nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") for i, r in enumerate(requests): # request id -> idx in list mapping @@ -126,7 +127,10 @@ def init_batch( ) @torch.no_grad() - def free_self(self): + def free_self(self) -> None: + """ + Free the memory of the InferBatch itself + """ remove_index = [] for idx in range(len(self)): remove_index.append( @@ -138,10 +142,13 @@ def free_self(self): ) remove_index = torch.cat(remove_index, dim=-1) self.cache_manager.free(remove_index) - return + @torch.no_grad() - def filter(self, request_ids: List[int]): + def filter(self, request_ids: List[int]) -> 'InferBatch': + """ + Filter out finished requests and return a new InferBatch with the remaining ones. + """ if len(request_ids) == 0: raise ValueError("Batch must have at least one request") if len(request_ids) == len(self): @@ -219,7 +226,10 @@ def filter(self, request_ids: List[int]): @classmethod @torch.no_grad() - def merge(cls, batch1, batch2): + def merge(cls, batch1, batch2) -> 'InferBatch': + """ + Return the merged InferBatch + """ requests = batch1.requests + batch2.requests requests_idx_mapping = {} new_batch_size = len(batch1) + len(batch2) @@ -288,7 +298,7 @@ def merge(cls, batch1, batch2): def __len__(self): return len(self.requests) - def get_post_sample_tensors(self): + def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: presence_penalties: List[float] = [] frequency_penalties: List[float] = [] temperatures: List[float] = [] diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 5b66744f338d..2b2739f0ae90 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -36,10 +36,8 @@ def stop_sequences_matched(self): if self.sample_params.stop_sequences is not None: for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) - if stop_len > 0: - if len(self.output_ids) >= stop_len: - if all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): - return True + if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + return True return False def __repr__(self): diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py index e7691ea4d3ed..d9e9b6269cc4 100644 --- a/colossalai/inference/dynamic_batching/req_queue.py +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -40,11 +40,8 @@ def _can_add_new_req(self, req): need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() # NOTE: change here < to <= - if need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size: - return True - else: - return False - + return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size + def generate_new_batch(self, current_batch: Batch = None): if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: return None diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py index 8af532dfa39c..9a0ace4111dd 100644 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -16,7 +16,7 @@ def __init__( top_k: int = -1, # -1 is for
all ignore_eos: bool = False, max_new_tokens: int = 16, - stop_sequences: Optional[Union[str, List[str]]] = None # 停止句子条件 + stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation ) -> None: self.do_sample = do_sample self.presence_penalty = presence_penalty diff --git a/colossalai/inference/dynamic_batching/stas.py b/colossalai/inference/dynamic_batching/stats.py similarity index 100% rename from colossalai/inference/dynamic_batching/stas.py rename to colossalai/inference/dynamic_batching/stats.py diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index dd7336127f70..cc8d1994bba9 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -5,7 +5,7 @@ from .dynamic_batching.io_struct import Batch, Req from .dynamic_batching.req_queue import ReqQueue from .dynamic_batching.sampling_params import SamplingParams -from .dynamic_batching.stas import Stats +from .dynamic_batching.stats import Stats from .tensor_parallel import TPInferEngine @@ -42,8 +42,9 @@ def __init__( self.eos_id = eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - + self.stats_tool = Stats(log_stats, log_stats_interval) + self.mem_usage_interval = log_stats_interval * 2 def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): """ @@ -74,7 +75,7 @@ def loop_for_fwd(self): self._step() counter_count += 1 if self.running_batch is not None: - if counter_count % 50 == 0: + if counter_count % self.mem_usage_interval == 0: print( "current batch size:", len(self.running_batch.reqs), diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index cd23a36cd3d3..1d1e154b6e70 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -49,10 +49,10 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output.logits.mean() config = LlamaConfig( - num_hidden_layers=8, - hidden_size=4096, + num_hidden_layers=4, + hidden_size=128, intermediate_size=256, - num_attention_heads=32, + num_attention_heads=4, max_position_embeddings=128, num_labels=16, ) From c1683cf3725aa9225c4d9618b3ea5b60413da8ea Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 11 Oct 2023 10:03:14 +0800 Subject: [PATCH 12/12] fix bug --- colossalai/inference/manager.py | 28 +++++++++---------- .../test_dynamic_batching/test_forward.py | 21 +++++++------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index cc8d1994bba9..72f77406789f 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -224,20 +224,20 @@ def clean_up(self): def start_dynamic_batching(args, tp_engine, waiting_req_list): - try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - except Exception: - batch_manager.clean_up() - raise + # try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + # except Exception: + # batch_manager.clean_up() + # raise 
batch_manager.loop_for_fwd() return diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py index 1faf851e829b..63df491e5b52 100644 --- a/tests/test_infer/test_dynamic_batching/test_forward.py +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -1,5 +1,3 @@ -import argparse - import pytest import torch from packaging import version @@ -7,6 +5,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig import colossalai +from dataclasses import dataclass from colossalai.inference.dynamic_batching.io_struct import Req from colossalai.inference.dynamic_batching.sampling_params import SamplingParams from colossalai.inference.manager import start_dynamic_batching @@ -20,15 +19,17 @@ MAX_OUTPUT_LEN = 16 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") +@dataclass +class args: + max_total_token_num: int + batch_max_tokens: int + eos_id: int + disable_log_stats: bool + log_stats_interval: int + def run(): - parser = argparse.ArgumentParser() - parser.add_argument("--max_total_token_num", type=int, default=42, help="max_total_token_num") - parser.add_argument("-b", "--batch_max_tokens", type=int, default=42, help="max tokens of one batch") - parser.add_argument("--eos_id", type=int, default=0, help="The end token of a seq") - parser.add_argument("--disable_log_stats", type=bool, default=False) - parser.add_argument("--log_stats_interval", type=int, default=10) - args = parser.parse_args() + arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) sampling_params = SamplingParams() req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) @@ -49,7 +50,7 @@ def run(): shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - start_dynamic_batching(args=args, tp_engine=infer_engine, waiting_req_list=waiting_list) + start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) def check_dynamic_forward(rank, world_size, port):
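# A self-contained sketch of the stop-sequence check that io_struct.Req.stop_sequences_matched
# performs after the refactor earlier in this series: a request counts as finished once its
# generated ids end with one of the configured stop sequences. The token ids below are made up.
def stop_sequences_matched(output_ids, stop_sequences):
    for stop_token_ids in stop_sequences:
        stop_len = len(stop_token_ids)
        if (
            stop_len > 0
            and len(output_ids) >= stop_len
            and all(output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
        ):
            return True
    return False


print(stop_sequences_matched([5, 7, 42, 13], [[42, 13]]))  # True
print(stop_sequences_matched([5, 7, 42, 13], [[7, 42]]))  # False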