From 6779d5bebffb1f1e00b5470a060bb389fdf263b8 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 8 Nov 2023 10:55:36 +0800 Subject: [PATCH 1/5] remove useless code --- colossalai/inference/async_engine.py | 133 ----- colossalai/inference/async_manager.py | 151 ----- .../inference/dynamic_batching/__init__.py | 0 .../dynamic_batching/get_tokenizer.py | 40 -- .../inference/dynamic_batching/infer_batch.py | 346 ----------- .../inference/dynamic_batching/io_struct.py | 166 ------ .../dynamic_batching/ray_dist_init.py | 152 ----- .../dynamic_batching/ray_init_config.py | 58 -- .../inference/dynamic_batching/req_queue.py | 73 --- .../dynamic_batching/sampling_params.py | 83 --- .../inference/dynamic_batching/stats.py | 45 -- colossalai/inference/hybridengine/engine.py | 4 +- .../microbatch_manager.py | 3 +- .../inference/hybridengine/modeling/bloom.py | 2 +- .../inference/hybridengine/modeling/llama.py | 2 +- .../inference/kvcache_manager/__init__.py | 2 + .../batch_infer_state.py | 0 .../kvcache_manager.py | 0 colossalai/inference/manager.py | 296 ---------- colossalai/inference/pipeline/README.md | 83 --- colossalai/inference/pipeline/__init__.py | 3 - .../inference/tensor_parallel/__init__.py | 4 - .../inference/tensor_parallel/engine.py | 480 --------------- .../tensor_parallel/modeling/__init__.py | 5 - .../tensor_parallel/modeling/_utils.py | 67 --- .../tensor_parallel/modeling/bloom.py | 537 ----------------- .../tensor_parallel/modeling/chatglm2.py | 545 ------------------ .../tensor_parallel/modeling/llama.py | 423 -------------- .../tensor_parallel/policies/__init__.py | 5 - .../tensor_parallel/policies/bloom.py | 99 ---- .../tensor_parallel/policies/chatglm2.py | 77 --- .../tensor_parallel/policies/llama.py | 119 ---- colossalai/pipeline/schedule/generate.py | 2 +- .../inference}/benchmark.py | 0 .../benchmark => examples/inference}/run.sh | 0 tests/test_infer/test_bloom_infer.py | 70 --- tests/test_infer/test_chatglm2_infer.py | 83 --- .../test_dynamic_batching/config.yaml | 14 - .../test_async_engine.py | 61 -- .../test_dynamic_batching_manager.py | 95 --- .../test_offline_dynamic_batching.py | 84 --- .../test_dynamic_batching/test_ray_dist.py | 66 --- tests/test_infer/test_infer_engine.py | 102 ---- tests/test_infer/test_kvcache_manager.py | 2 +- tests/test_infer/test_llama2_infer.py | 75 --- tests/test_infer/test_llama_infer.py | 73 --- 46 files changed, 9 insertions(+), 4721 deletions(-) delete mode 100644 colossalai/inference/async_engine.py delete mode 100644 colossalai/inference/async_manager.py delete mode 100644 colossalai/inference/dynamic_batching/__init__.py delete mode 100644 colossalai/inference/dynamic_batching/get_tokenizer.py delete mode 100644 colossalai/inference/dynamic_batching/infer_batch.py delete mode 100644 colossalai/inference/dynamic_batching/io_struct.py delete mode 100644 colossalai/inference/dynamic_batching/ray_dist_init.py delete mode 100644 colossalai/inference/dynamic_batching/ray_init_config.py delete mode 100644 colossalai/inference/dynamic_batching/req_queue.py delete mode 100644 colossalai/inference/dynamic_batching/sampling_params.py delete mode 100644 colossalai/inference/dynamic_batching/stats.py rename colossalai/inference/{pipeline => hybridengine}/microbatch_manager.py (98%) create mode 100644 colossalai/inference/kvcache_manager/__init__.py rename colossalai/inference/{tensor_parallel => kvcache_manager}/batch_infer_state.py (100%) rename colossalai/inference/{tensor_parallel => kvcache_manager}/kvcache_manager.py (100%) delete mode 100644 colossalai/inference/manager.py delete mode 100644 colossalai/inference/pipeline/README.md delete mode 100644 colossalai/inference/pipeline/__init__.py delete mode 100644 colossalai/inference/tensor_parallel/__init__.py delete mode 100644 colossalai/inference/tensor_parallel/engine.py delete mode 100644 colossalai/inference/tensor_parallel/modeling/__init__.py delete mode 100644 colossalai/inference/tensor_parallel/modeling/_utils.py delete mode 100644 colossalai/inference/tensor_parallel/modeling/bloom.py delete mode 100644 colossalai/inference/tensor_parallel/modeling/chatglm2.py delete mode 100644 colossalai/inference/tensor_parallel/modeling/llama.py delete mode 100644 colossalai/inference/tensor_parallel/policies/__init__.py delete mode 100644 colossalai/inference/tensor_parallel/policies/bloom.py delete mode 100644 colossalai/inference/tensor_parallel/policies/chatglm2.py delete mode 100644 colossalai/inference/tensor_parallel/policies/llama.py rename {colossalai/inference/pipeline/benchmark => examples/inference}/benchmark.py (100%) rename {colossalai/inference/pipeline/benchmark => examples/inference}/run.sh (100%) delete mode 100644 tests/test_infer/test_bloom_infer.py delete mode 100644 tests/test_infer/test_chatglm2_infer.py delete mode 100644 tests/test_infer/test_dynamic_batching/config.yaml delete mode 100644 tests/test_infer/test_dynamic_batching/test_async_engine.py delete mode 100644 tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py delete mode 100644 tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py delete mode 100644 tests/test_infer/test_dynamic_batching/test_ray_dist.py delete mode 100644 tests/test_infer/test_infer_engine.py delete mode 100644 tests/test_infer/test_llama2_infer.py delete mode 100644 tests/test_infer/test_llama_infer.py diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py deleted file mode 100644 index d0890ba3e9fc..000000000000 --- a/colossalai/inference/async_engine.py +++ /dev/null @@ -1,133 +0,0 @@ -import asyncio - -from colossalai.inference.dynamic_batching.ray_dist_init import Driver - -from .dynamic_batching.io_struct import RequestOutput -from .dynamic_batching.sampling_params import SamplingParams - - -class RequestTracker: - """ - A class for trace down all the requests, abstraction for async - """ - - def __init__(self) -> None: - self._requests: asyncio.Queue[str] = asyncio.Queue() - self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue() - self.new_requests_event = None - - def __contains__(self, item): - return item in self._requests - - def init_event(self): - self.new_requests_event = asyncio.Event() - - def add_request(self, request_id: str): - """Add a request to be sent to the engine on the next background - loop iteration.""" - self._requests.put_nowait(request_id) - self.new_requests_event.set() # NOTE: we may find a better way to clear this event - - def add_stop(self): - """ - Add a StopIteration flag to stop async generator. - """ - self._finished_requests.put_nowait(StopIteration) - self.new_requests_event.clear() - - def process_request_output(self, request_output: RequestOutput) -> None: - """Process a request output from the engine.""" - self._finished_requests.put_nowait(request_output) - - async def wait_for_new_requests(self): - await self.new_requests_event.wait() - - def __aiter__(self): - return self - - async def __anext__(self) -> RequestOutput: - result = await self._finished_requests.get() - # print("result of ", result) - if result is StopIteration: - raise StopAsyncIteration - return result - - -class Async_Engine: - - """ - Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager - Background loop: inference reqs in waiting list (Listen) - Request Tracker: manage incoming requests and restore finished ones - Generate: exposed func for add new input and return finished ones - """ - - def __init__( - self, - router_config, - engine_config, - start_engine_loop: bool = True, - ) -> None: - self.driver = Driver(router_config=router_config, engine_config=engine_config) - self.background_loop = None - self.start_engine_loop = start_engine_loop - self._request_tracker = RequestTracker() - - def _step(self): - """ - Logic for handling requests - """ - request_outputs = self.driver.step() - if request_outputs is not None: - for request_output in request_outputs: - self._request_tracker.process_request_output(request_output) - self._request_tracker.add_stop() - - def abort_request(self, request_id: str): - self.driver.abort(request_id) - - def _has_requests_in_progress(self): - return self.driver.is_running() - - async def run_loop_fwd(self): - has_requests_in_progress = self._has_requests_in_progress() - while True: - if not has_requests_in_progress: - await self._request_tracker.wait_for_new_requests() - self._step() - await asyncio.sleep(0) - - @property - def is_running(self): - return self.background_loop is not None and not self.background_loop.done() - - def start_background_loop(self): - if self.is_running: - raise RuntimeError("Background loop is already running.") - - self._request_tracker.init_event() - - self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) - self.background_loop = asyncio.shield(self.background_loop_unshielded) - - async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): - self.driver.add_input(request_id, prompt, sampling_params) - self._request_tracker.add_request(request_id) - - async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - """ - The only exposed func, adding new request and return a async generator that yields the existing results. - """ - try: - if not self.is_running: - self.start_background_loop() - - await self.add_request(request_id, prompt, sampling_params) - - async for request_output in self._request_tracker: - yield request_output - - except (Exception, asyncio.CancelledError) as e: - # If there is an exception or coroutine is cancelled, abort the request. - self.abort_request(request_id) - raise e diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py deleted file mode 100644 index 60440a792f1c..000000000000 --- a/colossalai/inference/async_manager.py +++ /dev/null @@ -1,151 +0,0 @@ -from typing import List - -from .dynamic_batching.io_struct import Batch, Req, RequestOutput -from .manager import DynamicBatchManager -from .tensor_parallel import TPInferEngine - - -class Async_DynamicBatchManager(DynamicBatchManager): - def __init__( - self, - tp_engine: TPInferEngine, - max_total_token_num: int, - batch_max_tokens: int, - model: str, - tokenizer=None, - eos_id=None, - log_stats=True, - log_stats_interval=10, - running_batch: Batch = None, - waiting_req_list: List = [], - ): - """ - Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager - max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) - batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests - running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine - eos_id : The end token of a seq - model: the model weight dir path, the app will load config, weights and tokenizer from this dir - log_stats : whether to log stats - log_stats_interval : log stats interval - running_batch : running batch - waiting_req_list : list of waiting requests, initialized before dynamic batch manager - """ - super().__init__( - tp_engine, - max_total_token_num, - batch_max_tokens, - model, - tokenizer, - eos_id, - log_stats, - log_stats_interval, - running_batch, - waiting_req_list, - ) - - def _step(self): - """ - Logic for handling requests - """ - has_new_finished = False - if self.running_batch is None: - new_batch = self.req_queue.generate_new_batch(self.running_batch) - if new_batch is not None: - self.stats_tool.count_prompt_tokens(new_batch) - self.running_batch = new_batch - has_new_finished, outputs = self._prefill_batch(self.running_batch) - self._filter_runing_batch() - self.has_wait_tokens = 0 - - else: - if self.has_wait_tokens < self.max_wait_tokens: - self.stats_tool.count_output_tokens(self.running_batch) - has_new_finished, outputs = self._decode_batch(self.running_batch) - self._filter_runing_batch() - self.has_wait_tokens += 1 - - else: - new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) - if new_mini_batch is not None: - self.stats_tool.count_prompt_tokens(new_mini_batch) - has_new_finished, outputs = self._prefill_batch(new_mini_batch) - if not new_mini_batch.is_clear(): - self._merge_batch(self.running_batch, new_mini_batch) - self.running_batch.merge(new_mini_batch) - self.has_wait_tokens = 0 - - else: - self.stats_tool.count_output_tokens(self.running_batch) - has_new_finished, outputs = self._decode_batch(self.running_batch) - self._filter_runing_batch() - self.has_wait_tokens += 1 - - if has_new_finished: - return outputs - return None - - def _prefill_batch(self, batch): - """ - For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. - """ - self._init_batch(batch) - - # TODO: figure out if cache and batch id is needed - ans = self.engine._prefill_batch(batch.batch_id) - req_to_out_token_id = ans - self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) - outputs = self._handle_finish_req(batch, has_new_finished_req) - return has_new_finished_req, outputs - # delete finished reqs - - def _decode_batch(self, batch: Batch): - """ - Decoding process - """ - ans = self.engine._decode_batch(batch.batch_id) - req_to_out_token_id = ans - self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) - outputs = self._handle_finish_req(batch, has_new_finished_req) - return has_new_finished_req, outputs - - def _handle_finish_req(self, batch: Batch, has_new_finished_req): - if has_new_finished_req: - finished_reqs = batch.filter_finished() - if batch.is_clear(): - self._remove_batch(batch) - else: - self._filter_batch(batch) - return self._output_process(finished_reqs) - return None - - def _output_process(self, finished_reqs: List[Req]): - """ - Process the output of a batch. - """ - outputs = [] - for req in finished_reqs: - output = self.tokenizer.decode(req.output_ids) - outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) - return outputs - - -def start_dynamic_batching(args, tp_engine, waiting_req_list): - try: - batch_manager = Async_DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - model=args.model, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - except Exception: - raise Exception - - return batch_manager diff --git a/colossalai/inference/dynamic_batching/__init__.py b/colossalai/inference/dynamic_batching/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py deleted file mode 100644 index 94aa3f24393f..000000000000 --- a/colossalai/inference/dynamic_batching/get_tokenizer.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. - -license: MIT, see LICENSE for more details. -""" - -from transformers import AutoTokenizer - -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" - - -def get_tokenizer( - tokenizer=None, - tokenizer_name: str = "", - trust_remote_code: bool = False, - use_fast: bool = True, -): - if tokenizer is not None: - tokenizer = tokenizer - else: - if "llama" in tokenizer_name.lower() and use_fast == True: - print( - "For some LLaMA-based models, initializing the fast tokenizer may " - "take a long time. To eliminate the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer. This is done automatically in Colossalai." - ) - - tokenizer_name = _FAST_LLAMA_TOKENIZER - - try: - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code - ) - except TypeError: - use_fast = False - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code - ) - return tokenizer diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py deleted file mode 100644 index 112784c15f84..000000000000 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ /dev/null @@ -1,346 +0,0 @@ -# Adapted from https://github.com/ModelTC/lightllm - -import collections -from dataclasses import dataclass -from typing import Dict, List, Tuple - -import numpy as np -import torch - -from colossalai.inference.tensor_parallel import MemoryManager - - -# make batch infer state an attr of InferBatch -class InferSamplingParams: - def __init__( - self, - do_sample: bool = False, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - vocab_size: int = -1, - ) -> None: - self.do_sample = do_sample - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.temperature = temperature - self.top_p = top_p - self.top_k = top_k - if self.top_k == -1: - self.top_k = vocab_size - return - - -@dataclass -class InferBatch: - batch_id: int - requests: List - requests_idx_mapping: Dict[int, int] - - input_ids: torch.Tensor - - all_input_ids: List[List[int]] - input_lengths: List[int] - - out_token_id_counts: List - sampling_param_list: List[InferSamplingParams] - - nopad_total_token_num: int - nopad_max_len_in_batch: int - nopad_b_loc: torch.Tensor - nopad_b_start_loc: torch.Tensor - nopad_b_seq_len: torch.Tensor - cache_manager: MemoryManager - max_total_len: int - - @classmethod - @torch.no_grad() - def init_batch( - cls, - batch_id, - requests, - dtype: torch.dtype, - device: torch.device, - cache_manager: MemoryManager, - vocab_size: int, - max_total_len: int, - ) -> "InferBatch": - input_lengths = [] - all_input_ids = [] - requests_idx_mapping = {} - - out_token_id_counts = [] - sampling_param_list = [] - - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") - # to avoid memory leak , we pre-allocate 12 more space for each batch. - nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") - for i, r in enumerate(requests): - # request id -> idx in list mapping - requests_idx_mapping[r["request_id"]] = i - - tokenized_input = r["input_id"] - - input_length = len(tokenized_input) - input_lengths.append(input_length) - all_input_ids.append(tokenized_input) - out_token_id_counts.append(collections.defaultdict(int)) - - # postprocessor - sampling_param = r["sampling_param"] - sampling_param["vocab_size"] = vocab_size - sampling_param_list.append(InferSamplingParams(**sampling_param)) - - nopad_total_token_num += input_length - nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) - - nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") - nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] - - if len(requests) > 1: - input_ids = np.concatenate(all_input_ids, dtype=np.int64) - else: - input_ids = all_input_ids[0] - - # Create tensors on device - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - - return cls( - batch_id=batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - input_lengths=input_lengths, - all_input_ids=all_input_ids, - nopad_total_token_num=nopad_total_token_num, - nopad_max_len_in_batch=nopad_max_len_in_batch, - nopad_b_loc=nopad_b_loc, - nopad_b_start_loc=nopad_b_start_loc, - nopad_b_seq_len=nopad_b_seq_len, - out_token_id_counts=out_token_id_counts, - sampling_param_list=sampling_param_list, - cache_manager=cache_manager, - max_total_len=max_total_len, - ) - - @torch.no_grad() - def free_self(self) -> None: - """ - Free the memory of the InferBatch itself - """ - remove_index = [] - for idx in range(len(self)): - remove_index.append( - self.nopad_b_loc[ - idx, - (self.nopad_max_len_in_batch - 1) - - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), - ] - ) - remove_index = torch.cat(remove_index, dim=-1) - self.cache_manager.free(remove_index) - - @torch.no_grad() - def filter(self, request_ids: List[int]) -> "InferBatch": - """ - Filter finished batch and return a new InferBatch with left ones. - """ - if len(request_ids) == 0: - raise ValueError("Batch must have at least one request") - if len(request_ids) == len(self): - return self - requests_idx_mapping = {} - indices = [] - requests = [] - all_input_ids = [] - input_lengths = [] - nopad_total_token_num = 0 - nopad_max_len_in_batch = 0 - nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") - nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") - - left_idx = [] - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - left_idx.append(idx) - - left_idx_set = set(left_idx) - remove_index = [] - for idx in range(len(self)): - if idx not in left_idx_set: - remove_index.append( - self.nopad_b_loc[ - idx, - (self.nopad_max_len_in_batch - 1) - - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), - ] - ) - remove_index = torch.cat(remove_index, dim=-1) - self.cache_manager.free(remove_index) - - nopad_max_len_in_batch = 0 - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - indices.append(idx) - - nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] - nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() - nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] - nopad_total_token_num = torch.sum(nopad_b_seq_len).item() - - nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[ - indices, - (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1), - ] - for i, request_id in enumerate(request_ids): - idx = self.requests_idx_mapping[request_id] - requests_idx_mapping[request_id] = i - requests.append(self.requests[idx]) - all_input_ids.append(self.all_input_ids[idx]) - input_lengths.append(self.input_lengths[idx]) - - input_ids = self.input_ids[indices] - - return InferBatch( - batch_id=self.batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - input_lengths=input_lengths, - all_input_ids=all_input_ids, - nopad_total_token_num=nopad_total_token_num, - nopad_max_len_in_batch=nopad_max_len_in_batch, - nopad_b_loc=nopad_b_loc, - nopad_b_start_loc=nopad_b_start_loc, - nopad_b_seq_len=nopad_b_seq_len, - out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], - sampling_param_list=[self.sampling_param_list[_i] for _i in indices], - cache_manager=self.cache_manager, - max_total_len=self.max_total_len, - ) - - @classmethod - @torch.no_grad() - def merge(cls, batch1, batch2) -> "InferBatch": - """ - Return megerd new InferBatch - """ - requests = batch1.requests + batch2.requests - requests_idx_mapping = {} - new_batch_size = len(batch1) + len(batch2) - - input_ids = batch1.input_ids.new_empty(new_batch_size) - all_input_ids = [] - input_lengths = [] - out_token_id_counts = [] - sampling_param_list = [] - - cumulative_batch_size = 0 - nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num - nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) - max_total_len = max(batch1.max_total_len, batch2.max_total_len) - nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") - nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") - nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") - nopad_start_loc_len_temp = 0 - batches = [batch1, batch2] - for i, batch in enumerate(batches): - if i == 0: - requests_idx_mapping = batch.requests_idx_mapping - else: - for k, v in batch.requests_idx_mapping.items(): - requests_idx_mapping[k] = v + cumulative_batch_size - start_index = cumulative_batch_size - end_index = cumulative_batch_size + len(batch) - input_ids[start_index:end_index] = batch.input_ids - nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len - nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp - nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] - nopad_b_loc[ - start_index:end_index, - nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1, - ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1] - - all_input_ids.extend(batch.all_input_ids) - - input_lengths.extend(batch.input_lengths) - out_token_id_counts.extend(batch.out_token_id_counts) - sampling_param_list.extend(batch.sampling_param_list) - # Update - cumulative_batch_size += len(batch) - - nopad_b_loc[:, nopad_max_len_in_batch - 1] = ( - nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda") - ) - return InferBatch( - batch_id=batches[0].batch_id, - requests=requests, - requests_idx_mapping=requests_idx_mapping, - input_ids=input_ids, - input_lengths=input_lengths, - all_input_ids=all_input_ids, - nopad_total_token_num=nopad_total_token_num, - nopad_max_len_in_batch=nopad_max_len_in_batch, - nopad_b_loc=nopad_b_loc, - nopad_b_start_loc=nopad_b_start_loc, - nopad_b_seq_len=nopad_b_seq_len, - out_token_id_counts=out_token_id_counts, - sampling_param_list=sampling_param_list, - cache_manager=batches[0].cache_manager, - max_total_len=max_total_len, - ) - - def __len__(self): - return len(self.requests) - - def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - presence_penalties: List[float] = [] - frequency_penalties: List[float] = [] - temperatures: List[float] = [] - top_ps: List[float] = [] - top_ks: List[int] = [] - p_token_ids: List[int] = [] - p_token_counts: List[int] = [] - p_seq_len: List[int] = [ - 0, - ] - p_max_len_in_batch: int = 0 - for i, id_to_count in enumerate(self.out_token_id_counts): - sample_param = self.sampling_param_list[i] - presence_penalties.append(sample_param.presence_penalty) - frequency_penalties.append(sample_param.frequency_penalty) - temperatures.append(sample_param.temperature) - top_ps.append(sample_param.top_p) - top_ks.append(sample_param.top_k) - - for token_id, count in id_to_count.items(): - p_token_ids.append(token_id) - p_token_counts.append(count) - p_seq_len.append(len(id_to_count)) - p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) - - presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") - frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") - temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") - top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda") - top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda") - p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda") - p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") - p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") - p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32) - return ( - presence_penalties, - frequency_penalties, - temperatures, - top_ps, - top_ks, - p_token_ids, - p_token_counts, - p_cumsum_seq_len, - p_max_len_in_batch, - ) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py deleted file mode 100644 index fc5ecfe5796b..000000000000 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ /dev/null @@ -1,166 +0,0 @@ -# Adapted from https://github.com/ModelTC/lightllm - -from typing import Dict, List, Tuple - -from .sampling_params import SamplingParams - - -class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): - self.request_id = request_id - self.prompt_ids = prompt_ids - self.input_len = len(prompt_ids) - self.max_output_len = sample_params.max_new_tokens - self.sample_params = sample_params - self.output_ids = [] - self.output_metadata_list = [] - self.has_generate_finished = False - self.aborted = False - self.prompts = prompts - - def to_rpc_obj(self): - return { - "request_id": self.request_id, - "input_id": self.prompt_ids, - "output_len": self.max_output_len, - "sampling_param": self.sample_params.to_dict(), - } - - def stop_sequences_matched(self): - # should we add stpp sequences to the sample params? - if self.sample_params.stop_sequences is not None: - for stop_token_ids in self.sample_params.stop_sequences: - stop_len = len(stop_token_ids) - if ( - stop_len > 0 - and len(self.output_ids) >= stop_len - and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) - ): - return True - return False - - def __repr__(self): - return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " - - -class Batch: - def __init__(self, batch_id, reqs: List[Req]): - self.batch_id = batch_id - self.reqs = reqs - self.id_to_reqs = {req.request_id: req for req in reqs} - - def input_tokens(self): - batch_input_tokens = 0 - for req in self.reqs: - batch_input_tokens += req.input_len - return batch_input_tokens - - def calcu_max_tokens(self): - tokens = 0 - for req in self.reqs: - tokens += req.input_len + req.max_output_len - return tokens - - def calcu_used_tokens(self): - tokens = 0 - for req in self.reqs: - tokens += req.input_len + len(req.output_ids) - return tokens - - def mark_finished_req(self, eos_id, engine_max_output_len): - has_new_finish = False - for req in self.reqs: - if req.stop_sequences_matched(): - req.has_generate_finished = True - has_new_finish = True - if len(req.output_ids) >= engine_max_output_len: - req.has_generate_finished = True - has_new_finish = True - if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: - req.has_generate_finished = True - has_new_finish = True - if len(req.output_ids) >= req.max_output_len or req.aborted: - req.has_generate_finished = True - has_new_finish = True - return has_new_finish - - def filter_finished(self) -> List[Req]: - """ - Filter finished requests from the batch, the finished ones will be removed from 'reqs'. - """ - # TODO: the logic of return should be defined here. - unfinished_req = [] - finished_req = [] - for req in self.reqs: - if not req.has_generate_finished: - unfinished_req.append(req) - else: - finished_req.append(req) - self.reqs = unfinished_req - self.id_to_reqs = {req.request_id: req for req in self.reqs} - return finished_req - - def is_clear(self): - return len(self.reqs) == 0 - - def merge(self, mini_batch): - for _req in mini_batch.reqs: - self.reqs.append(_req) - self.id_to_reqs = {req.request_id: req for req in self.reqs} - return - - def __repr__(self): - return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " - - def __len__(self): - return len(self.reqs) - - -class BatchTokenIdOut: - def __init__(self): - self.reqs_infs: List[ - Tuple[str, int, Dict, bool, bool] - ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] - - -class BatchStrOut: - def __init__(self): - self.reqs_infs: List[ - Tuple[str, str, Dict, bool, bool] - ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] - - -class AbortReq: - def __init__(self, req_id): - self.req_id = req_id - - -class RequestOutput: - """The output data of a request to the LLM. - - Args: - request_id: The unique ID of the request. - prompt: The prompt string of the request. - prompt_token_ids: The token IDs of the prompt. - outputs: The output sequences of the request. - """ - - def __init__( - self, - request_id: str, - prompt: str, - prompt_token_ids: List[int], - outputs, - ) -> None: - self.request_id = request_id - self.prompt = prompt - self.prompt_token_ids = prompt_token_ids - self.outputs = outputs - - def __repr__(self) -> str: - return ( - f"RequestOutput(request_id={self.request_id}, " - f"prompt={self.prompt!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"outputs={self.outputs}, " - ) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py deleted file mode 100644 index 70ef489d3b70..000000000000 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ /dev/null @@ -1,152 +0,0 @@ -import logging -import os -from typing import List - -import ray -import ray.util.collective as collective -import torch -from transformers import AutoModelForCausalLM - -import colossalai -from colossalai.inference.async_manager import start_dynamic_batching -from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer -from colossalai.inference.dynamic_batching.io_struct import RequestOutput -from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass -from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.tensor_parallel.engine import TPInferEngine -from colossalai.shardformer import ShardConfig -from colossalai.testing import free_port - -ray_serve_logger = logging.getLogger("ray.serve") - - -def log_cuda_info(scope_name: str): - ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") - ray_serve_logger.info( - f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" - ) - if torch.cuda.is_available(): - ray_serve_logger.info( - f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" - ) - else: - ray_serve_logger.info(f" {scope_name}: cuda is not available!") - - -@ray.remote(num_gpus=1) -class Worker: - def __init__( - self, - model_path: str, - tensor_parallel_size: int, - max_batch_size: int, - max_input_len: int, - max_output_len: int, - router_config: RooterArgsClass, - ): - log_cuda_info("Worker.init") - self.tensor_parallel_size = tensor_parallel_size - self.model_path = model_path - self.max_batch_size = max_batch_size - self.max_input_len = max_input_len - self.max_output_len = max_output_len - self.router_config = router_config - - def setup(self, world_size, rank, port): - # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully - collective.init_collective_group(world_size, rank, "nccl", "default") - # initialize and set distributed environment - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") - log_cuda_info("Worker.setup") - - # Load model - self.tokenizer = get_tokenizer(tokenizer_name=self.model_path) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - self.model = AutoModelForCausalLM.from_pretrained( - self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 - ) - shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) - self.infer_engine = TPInferEngine( - self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len - ) - self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, []) - - return True - - # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]: - # ray_serve_logger.info(f"text: {prompt}") - - # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) - - # return final_outputs - - def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - self.start_dynamic_batching.add_input(request_id, prompt, sampling_params) - - def abort(self, request_id: str): - self.start_dynamic_batching.abort(request_id) - - def step(self) -> List[RequestOutput]: - return self.start_dynamic_batching._step() - - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): - self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - - def is_running(self): - return self.start_dynamic_batching.is_running() - - -class Driver: - def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): - log_cuda_info("Driver:init") - model_path = engine_config.model - tensor_parallel_size = engine_config.tensor_parallel_size - - self.num_workers = tensor_parallel_size - self.workers = [] - init_rets = [] - - # Just grab a free port on localhost - # NOTE workers in this communication group listen to the same port - available_port = free_port() - - for i in range(self.num_workers): - worker_name = "worker_idx_{}".format(i) - w = Worker.options(name=worker_name).remote( - model_path, - self.num_workers, - engine_config.max_batch_size, - engine_config.max_input_len, - engine_config.max_output_len, - router_config, - ) - self.workers.append(w) - init_rets.append(w.setup.remote(self.num_workers, i, available_port)) - _options = { - "group_name": "default_driver", - "world_size": self.num_workers, - "ranks": [i for i in range(self.num_workers)], - "backend": "nccl", - } - collective.create_collective_group(self.workers, **_options) - _ = ray.get(init_rets) - - def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) - - def abort(self, request_id: str): - ray.get([w.abort.remote(request_id) for w in self.workers]) - - def step(self): - results = ray.get([w.step.remote() for w in self.workers]) - outputs = results[0] # get any one of the copies - return outputs - - def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str): - ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) - - def is_running(self): - results = ray.get([w.is_running.remote() for w in self.workers]) - return any(results) diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py deleted file mode 100644 index 471f07330aec..000000000000 --- a/colossalai/inference/dynamic_batching/ray_init_config.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging - -import yaml -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -class EngineArgsClass(BaseModel): - """Config for Engine""" - - model: str - tensor_parallel_size: int = 2 - max_batch_size: int = 4 - max_input_len: int = 128 - max_output_len: int = 32 - - -class RooterArgsClass(BaseModel): - """Config for Rooter""" - - max_total_token_num: int = 42 - batch_max_tokens: int = 42 - eos_id: int = 0 - disable_log_stats: bool = False - log_stats_interval: int = 10 - model: str - - -class RayInitConfig(BaseModel): - """All-together configs without app router config""" - - engine_config_data: EngineArgsClass - router_config_data: RooterArgsClass - - @classmethod - def from_yaml_path(cls, path: str): - try: - with open(path, "r") as yaml_file: - try: - config = yaml.safe_load(yaml_file) - # serve deployment config - engine_config = config.get("engine_config", {}) - router_config = config.get("router_config", {}) - - return cls( - engine_config_data=engine_config, - router_config_data=router_config, - ) - except yaml.YAMLError as e: - logger.error(f"An Error occurred when parsing yaml: {e}") - raise - except FileNotFoundError: - logger.error(f"The file '{path}' does not exist!") - raise - except OSError as e: - logger.error(f"An Error occurred: {e}") - raise diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py deleted file mode 100644 index 0de43bd1a21f..000000000000 --- a/colossalai/inference/dynamic_batching/req_queue.py +++ /dev/null @@ -1,73 +0,0 @@ -# Adapted from https://github.com/ModelTC/lightllm - -import uuid -from typing import List - -import numpy as np - -from .io_struct import Batch, Req - - -class ReqQueue: - def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None: - self.max_total_tokens = max_total_tokens - assert batch_max_tokens is not None - self.batch_max_tokens = batch_max_tokens - self.running_max_req_size = running_max_req_size - self.waiting_req_list: List[Req] = waiting_req_list - - def append(self, req): - self.waiting_req_list.append(req) - return - - def _init_cache_list(self, current_batch: Batch): - if current_batch is not None: - self.cache_len_list = [ - (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) - for req in current_batch.reqs - ] - else: - self.cache_len_list = [] - - # @calculate_time(show=True, min_cost_ms=0.1) - def _can_add_new_req(self, req): - self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis - self.cache_len_list.sort(key=lambda x: -x[1]) - - left_out_len_array = np.array([e[1] for e in self.cache_len_list]) - # assert left_out_len_array.min() >= 0 - has_run_len_array = np.array([e[0] for e in self.cache_len_list]) - cum_run_len_array = np.cumsum(has_run_len_array) - size_array = np.arange(1, len(self.cache_len_list) + 1, 1) - - need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() - # NOTE: change here < to <= - return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size - - def generate_new_batch(self, current_batch: Batch = None): - if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: - return None - self._init_cache_list(current_batch) - can_run_list = [] - new_batch_total_tokens = 0 - aborted_count = 0 - for req in self.waiting_req_list: - flag = self._can_add_new_req(req) - if req.aborted: - aborted_count += 1 - continue - if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: - can_run_list.append(req) - new_batch_total_tokens += req.input_len - else: - break - - if len(can_run_list) != 0: - new_batch = Batch(uuid.uuid4().hex, can_run_list) - self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] - return new_batch - else: - return None - - def __len__(self): - return self.waiting_req_list.__len__() diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py deleted file mode 100644 index a37a83390021..000000000000 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ /dev/null @@ -1,83 +0,0 @@ -# Adapted from https://github.com/ModelTC/lightllm - -"""Sampling parameters for text generation.""" -from typing import List, Optional, Union - -_SAMPLING_EPS = 1e-5 - - -class SamplingParams: - def __init__( - self, - do_sample: bool = False, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, # -1 is for all - ignore_eos: bool = False, - max_new_tokens: int = 256, - stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation - ) -> None: - self.do_sample = do_sample - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.temperature = temperature - self.top_p = top_p - self.top_k = top_k - self.ignore_eos = ignore_eos - self.max_new_tokens = max_new_tokens - self.stop_sequences = stop_sequences - if self.do_sample == False: - self.temperature = 1.0 - self.top_p = 1.0 - self.top_k = 1 - if ( - self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS - ): # temperature is too slow, change to greedy search - self.temperature = 1.0 - self.top_k = 1 - return - - def verify(self): - if self.presence_penalty < 0.0: - raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") - if self.frequency_penalty < 0.0: - raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}") - if self.temperature <= 0.0: - raise ValueError(f"temperature must > 0.0, got {self.temperature}") - if self.top_p <= 0.0 or self.top_p > 1.0: - raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}") - if self.top_k < -1 or self.top_k == 0: - raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.") - if self.max_new_tokens < 1: - raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.") - return - - def stop_sentences_to_token_ids(self, tokenizer): - if self.stop_sequences is None: - self.stop_sequences = [] - else: - if isinstance(self.stop_sequences, str): - self.stop_sequences = [self.stop_sequences] - new_stop_sequences = [] - for stop_str in self.stop_sequences: - stop_str_ids = tokenizer.encode(stop_str) - if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id - stop_str_ids = stop_str_ids[1:] - if len(stop_str_ids) > 0: - new_stop_sequences.append(stop_str_ids) - self.stop_sequences = new_stop_sequences - return - - def to_dict(self): - ret = {} - ret["do_sample"] = self.do_sample - ret["presence_penalty"] = self.presence_penalty - ret["frequency_penalty"] = self.frequency_penalty - ret["temperature"] = self.temperature - ret["top_p"] = self.top_p - ret["top_k"] = self.top_k - # if self.ignore_eos is not None: - # ret["ignore_eos"] = self.ignore_eos - return ret diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py deleted file mode 100644 index 524072861a3f..000000000000 --- a/colossalai/inference/dynamic_batching/stats.py +++ /dev/null @@ -1,45 +0,0 @@ -# Adapted from https://github.com/ModelTC/lightllm - -import time - - -class Stats: - def __init__(self, log_status, log_stats_interval) -> None: - self.log_stats = log_status - self.log_stats_interval = log_stats_interval - self.last_log_time = time.time() - self.all_tokens = 0 - self.output_tokens = 0 - self.prompt_tokens = 0 - return - - def count_prompt_tokens(self, run_batch): - if self.log_stats: - tokens = run_batch.input_tokens() - self.prompt_tokens += tokens - self.all_tokens += tokens - return - - def count_output_tokens(self, run_batch): - if self.log_stats: - tokens = len(run_batch.reqs) - self.output_tokens += tokens - self.all_tokens += tokens - return - - def print_stats(self): - if not self.log_stats: - return - - now = time.time() - if now - self.last_log_time > self.log_stats_interval: - print( - f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n" - f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n" - f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s" - ) - self.all_tokens = 0 - self.output_tokens = 0 - self.prompt_tokens = 0 - self.last_log_time = now - return diff --git a/colossalai/inference/hybridengine/engine.py b/colossalai/inference/hybridengine/engine.py index 5e944014b565..9248d45ff1c3 100644 --- a/colossalai/inference/hybridengine/engine.py +++ b/colossalai/inference/hybridengine/engine.py @@ -9,8 +9,8 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy -from ..pipeline.microbatch_manager import MicroBatchManager -from ..tensor_parallel.kvcache_manager import MemoryManager +from ..kvcache_manager import MemoryManager +from .microbatch_manager import MicroBatchManager PP_AXIS, TP_AXIS = 0, 1 diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/hybridengine/microbatch_manager.py similarity index 98% rename from colossalai/inference/pipeline/microbatch_manager.py rename to colossalai/inference/hybridengine/microbatch_manager.py index 441cf603985c..bf50a160354e 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/hybridengine/microbatch_manager.py @@ -3,8 +3,7 @@ import torch -from ..tensor_parallel.batch_infer_state import BatchInferState -from ..tensor_parallel.kvcache_manager import MemoryManager +from ..kvcache_manager import BatchInferState, MemoryManager __all__ = "MicroBatchManager" diff --git a/colossalai/inference/hybridengine/modeling/bloom.py b/colossalai/inference/hybridengine/modeling/bloom.py index 45bd9f39bf0d..d2276caed4d0 100644 --- a/colossalai/inference/hybridengine/modeling/bloom.py +++ b/colossalai/inference/hybridengine/modeling/bloom.py @@ -14,7 +14,7 @@ ) from transformers.utils import logging -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.kvcache_manager.batch_infer_state import BatchInferState from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager diff --git a/colossalai/inference/hybridengine/modeling/llama.py b/colossalai/inference/hybridengine/modeling/llama.py index d05c0af2214b..1719113664e7 100644 --- a/colossalai/inference/hybridengine/modeling/llama.py +++ b/colossalai/inference/hybridengine/modeling/llama.py @@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel from transformers.utils import logging -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.kvcache_manager.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.pipeline.stage_manager import PipelineStageManager diff --git a/colossalai/inference/kvcache_manager/__init__.py b/colossalai/inference/kvcache_manager/__init__.py new file mode 100644 index 000000000000..5b6ca182efae --- /dev/null +++ b/colossalai/inference/kvcache_manager/__init__.py @@ -0,0 +1,2 @@ +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/kvcache_manager/batch_infer_state.py similarity index 100% rename from colossalai/inference/tensor_parallel/batch_infer_state.py rename to colossalai/inference/kvcache_manager/batch_infer_state.py diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/kvcache_manager/kvcache_manager.py similarity index 100% rename from colossalai/inference/tensor_parallel/kvcache_manager.py rename to colossalai/inference/kvcache_manager/kvcache_manager.py diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py deleted file mode 100644 index 9672a50141a0..000000000000 --- a/colossalai/inference/manager.py +++ /dev/null @@ -1,296 +0,0 @@ -# Adapted from https://github.com/ModelTC/lightllm - -import time -from typing import List - -from .dynamic_batching.get_tokenizer import get_tokenizer -from .dynamic_batching.infer_batch import InferBatch -from .dynamic_batching.io_struct import Batch, Req -from .dynamic_batching.req_queue import ReqQueue -from .dynamic_batching.sampling_params import SamplingParams -from .dynamic_batching.stats import Stats -from .tensor_parallel import TPInferEngine - - -class DynamicBatchManager: - def __init__( - self, - tp_engine: TPInferEngine, - max_total_token_num, - batch_max_tokens, - model, - tokenizer=None, - eos_id=None, - log_stats=True, - log_stats_interval=10, - running_batch: Batch = None, - waiting_req_list: List = [], - ): - """ - Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager - max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) - batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests - running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine - eos_id : The end token of a seq - model: the model weight dir path, the app will load config, weights and tokenizer from this dir - log_stats : whether to log stats - log_stats_interval : log stats interval - running_batch : running batch - waiting_req_list : list of waiting requests, initialized before dynamic batch manager - """ - self.engine = tp_engine - self.max_total_token_num = max_total_token_num - running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 - self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) - # all the inputs should be put into req_queue: waiting req list - assert max_total_token_num >= self.engine.max_batch_size * ( - self.engine.max_input_len + self.engine.max_output_len - ), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)" - assert ( - batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len - ), "batch_max_tokens should be greater than (max_input_len+max_output_len)" - self.running_batch: Batch = running_batch - self.eos_id = eos_id - self.has_wait_tokens = 0 - self.max_wait_tokens = 10 - self.model = model - - self.stats_tool = Stats(log_stats, log_stats_interval) - self.mem_usage_interval = log_stats_interval * 2 - self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer - if self.eos_id == None: - self.eos_id = self.tokenizer.eos_token_id - - def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): - """ - Add new request to req queue, during initialization all requests are held in waiting list. - """ - sampling_params.max_new_tokens = ( - self.engine.max_output_len - if sampling_params.max_new_tokens > self.engine.max_output_len - else sampling_params.max_new_tokens - ) - req = Req(request_id, prompt_ids, sampling_params, prompts) - self.req_queue.append(req) - return - - def add_input(self, request_id, prompts, sampling_params): - """ - Encode and Add new input to req queue. support one sequence input for now. - """ - prompt_ids = self.tokenizer.encode(prompts) - prompt_len = len(prompt_ids) - if prompt_len > self.engine.max_input_len: - raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") - sampling_params.stop_sentences_to_token_ids(self.tokenizer) - self.add_req(request_id, prompt_ids, sampling_params, prompts) - return - - def abort(self, request_id): - if self.running_batch is not None: - for req in self.running_batch.reqs: - if req.request_id == request_id: - req.has_generate_finished = True - req.aborted = True - for req in self.req_queue.waiting_req_list: - if req.request_id == request_id: - req.has_generate_finished = True - req.aborted = True - return - - def loop_for_fwd(self): - """ - The main loop for a dynamic batching process. - """ - counter_count = 0 - # self.running_batch is not None or self.req_queue.waiting_req_list - while self.running_batch is not None or self.req_queue.waiting_req_list: - yield from self._step() - counter_count += 1 - if self.running_batch is not None: - if counter_count % self.mem_usage_interval == 0: - print( - "current batch size:", - len(self.running_batch.reqs), - "token used ratio:", - self.running_batch.calcu_used_tokens() / self.max_total_token_num, - ) - self.stats_tool.print_stats() - - if self.running_batch is None: - time.sleep(0.1) # 10ms - - def _step(self): - """ - Logic for handling requests - """ - - if self.running_batch is None: - new_batch = self.req_queue.generate_new_batch(self.running_batch) - if new_batch is not None: - self.stats_tool.count_prompt_tokens(new_batch) - self.running_batch = new_batch - yield from self._prefill_batch(self.running_batch) - self._filter_runing_batch() - self.has_wait_tokens = 0 - return - - if self.has_wait_tokens < self.max_wait_tokens: - self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) - self._filter_runing_batch() - self.has_wait_tokens += 1 - return - else: - new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) - if new_mini_batch is not None: - self.stats_tool.count_prompt_tokens(new_mini_batch) - yield from self._prefill_batch(new_mini_batch) - if not new_mini_batch.is_clear(): - self._merge_batch(self.running_batch, new_mini_batch) - self.running_batch.merge(new_mini_batch) - self.has_wait_tokens = 0 - - else: - self.stats_tool.count_output_tokens(self.running_batch) - yield from self._decode_batch(self.running_batch) - self._filter_runing_batch() - self.has_wait_tokens += 1 - - return - - def _init_batch(self, batch: Batch, dtype="fp16"): - reqs = [r.to_rpc_obj() for r in batch.reqs] - batch_id = batch.batch_id - - import torch - - if dtype == "fp16": - dtype = torch.float16 - else: - assert False, "error dtype" - - batch_data = InferBatch.init_batch( - batch_id, - reqs, - dtype, - torch.cuda.current_device(), - self.engine.cache_manager, - self.engine.model.config.vocab_size, - self.engine.max_input_len + self.engine.max_output_len, - ) - self.engine.cache[batch_id] = batch_data - - def _prefill_batch(self, batch): - """ - For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. - """ - self._init_batch(batch) - - # TODO: figure out if cache and batch id is needed - ans = self.engine._prefill_batch(batch.batch_id) - req_to_out_token_id = ans - self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) - yield from self._handle_finish_req(batch, has_new_finished_req) - - # delete finished reqs - - def _decode_batch(self, batch: Batch): - """ - Decoding process - """ - ans = self.engine._decode_batch(batch.batch_id) - req_to_out_token_id = ans - self._add_token_id_to_req(batch, req_to_out_token_id) - has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) - yield from self._handle_finish_req(batch, has_new_finished_req) - - def _filter_batch(self, batch: Batch): - batch_id = batch.batch_id - req_id_list = [r.request_id for r in batch.reqs] - batch = self.engine.cache.pop(batch_id) - filter_batch = batch.filter(req_id_list) - del batch - self.engine.cache[batch_id] = filter_batch - - def _merge_batch(self, batch1, batch2): - """ - Merge new mini batch into running batch. - """ - batch1 = self.engine.cache.pop(batch1.batch_id) - batch2 = self.engine.cache.pop(batch2.batch_id) - - m_batch = InferBatch.merge(batch1, batch2) - self.engine.cache[batch1.batch_id] = m_batch - del batch1 - del batch2 - - def _remove_batch(self, batch): - """ - Remove finished batch. - """ - batch = self.engine.cache.pop(batch.batch_id) - batch.free_self() - del batch - - def _handle_finish_req(self, batch: Batch, has_new_finished_req): - if has_new_finished_req: - finished_reqs = batch.filter_finished() - if batch.is_clear(): - self._remove_batch(batch) - else: - self._filter_batch(batch) - yield from self._output_process(finished_reqs) - - def _filter_runing_batch(self): - if self.running_batch is not None and self.running_batch.is_clear(): - self.running_batch = None - - def _add_token_id_to_req(self, batch: Batch, req_ans): - for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): - req = batch.id_to_reqs[req_id] - req.output_ids.append(new_token_id) - req.output_metadata_list.append(new_gen_metadata) - return - - def _output_process(self, finished_reqs: List[Req]): - """ - Process the output of a batch. - """ - for req in finished_reqs: - output = self.tokenizer.decode(req.output_ids) - yield req.prompts + output - - def clean_up(self): - # this logic should be implemented in the future. - pass - - def generate(self, request_id, prompts, sampling_params): - """ - Generate the output of a request. - """ - self.add_input(request_id, prompts, sampling_params) - return self.loop_for_fwd() - - def is_running(self): - return self.running_batch is not None or self.req_queue.waiting_req_list - - -def start_dynamic_batching(args, tp_engine, waiting_req_list): - try: - batch_manager = DynamicBatchManager( - tp_engine=tp_engine, - max_total_token_num=args.max_total_token_num, - batch_max_tokens=args.batch_max_tokens, - eos_id=args.eos_id, - model=args.model, - log_stats=not args.disable_log_stats, - log_stats_interval=args.log_stats_interval, - waiting_req_list=waiting_req_list, - ) - - except Exception: - raise Exception - - return batch_manager diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md deleted file mode 100644 index f9bb35cc4d4c..000000000000 --- a/colossalai/inference/pipeline/README.md +++ /dev/null @@ -1,83 +0,0 @@ -# 🐳 Pipeline Inference - -## Table of Contents -- [💡 Introduction](#introduction) -- [🔗 Design](#design) -- [🔨 Usage](#usage) - - [Example](#example) - - [Quick start](#quick-start) -- [📊 Performance](#performance) - -## Introduction - -`Pipeline Inference` is a module designed to make inference on a pipeline way. In inference systems, although there is no need to store intermediate information such as activations during forward propagation for backward propagation, the weights of some larger models still cannot fit on a single GPU for inference. This requires us to use model parallelism and other methods to reduce the memory occupation on a single GPU. Pipeline parallelism, as one of the traditional model parallelism approaches, has been widely used due to its reduced all-reduce communication requirements and simple layout. The main issue with pipeline parallelism, known as bubbles, can be almost eliminated in inference because the backward propagation that causes bubbles no longer exists in inference. This makes pipeline parallelism almost bubble-free in the ideal scenario where the sequence length is the same across the pipeline. - -## Design - -Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py). - -1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks: - - Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`. - - Run the pipeline inference model. - -2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks: - - Record each micro-batch information, like generated new tokens and kvcache. - - Record each micro-batch inference state, like prefill, generate or done. - - Update the micro-batch information. - -3. `generate` schedule implements the simple pipeline inference layout. When pipeline size is 2, we use `torch.distributed.P2Pop` to implement the communication between stages, mainly to solve the race communication. When pipeline size is larger than 2, we use `torch.distributed.broadcast` which is faster than `torch.distributed.P2Pop`. - -## Usage - -### Example -```python -from colossalai.inference import PPInferEngine -from colossalai.inference.pipeline.policies import LlamaModelInferPolicy -import colossalai -from transformers import LlamaForCausalLM, LlamaTokenizer - -colossalai.launch_from_torch(config={}) - -model = LlamaForCausalLM.from_pretrained("/path/to/model") -tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") - -# assume the model is inferred with 2 pipeline stages -inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32) - -input = ["Introduce a landmark in London","Introduce a landmark in Singapore"] -data = tokenizer(input, return_tensors='pt') -output = inferengine.inference(data.to('cuda')) -print(tokenizer.batch_decode(output)) -``` - -## Performance - -We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G. - -### Llama Throughput (tokens/s) | input length=1024, output length=128 - -#### A10 7b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| -| :---: | :---: | :---: | :---: | :---: | :---: | :---:| -| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM | -| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | - -#### A10 13b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | -| :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | -| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | - - -#### A800 7b, fp16 -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | -| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | - - -#### A800 13b, fp16 -| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 | -| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 | diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py deleted file mode 100644 index f43e4a847448..000000000000 --- a/colossalai/inference/pipeline/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .microbatch_manager import MicroBatchManager - -__all__ = ["MicroBatchManager"] diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py deleted file mode 100644 index 112b920ba158..000000000000 --- a/colossalai/inference/tensor_parallel/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .engine import TPInferEngine -from .kvcache_manager import MemoryManager - -__all__ = ["MemoryManager", "TPInferEngine"] diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py deleted file mode 100644 index 2478b574d307..000000000000 --- a/colossalai/inference/tensor_parallel/engine.py +++ /dev/null @@ -1,480 +0,0 @@ -from typing import Any, Callable, List, Optional, Union - -import torch -import torch.nn as nn -from transformers import BloomForCausalLM, LlamaForCausalLM -from transformers.generation import GenerationConfig -from transformers.generation.stopping_criteria import StoppingCriteriaList -from transformers.tokenization_utils_base import BatchEncoding - -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.policies.auto_policy import get_autopolicy - -from .batch_infer_state import BatchInferState -from .kvcache_manager import MemoryManager - -# from dynamic_batching.infer_batch import InferBatch - -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 - -_supported_models = [ - "LlamaForCausalLM", - "LlamaModel", - "BloomForCausalLM", - "ChatGLMModel", - "ChatGLMForConditionalGeneration", - "LlamaGPTQForCausalLM", - "BloomGPTQForCausalLM", -] - - -class TPInferEngine: - """Engine class for tensor parallel inference. - - Args: - model (Module): original model, e.g. huggingface CausalLM - shard_config (ShardConfig): The config for sharding original model - max_batch_size (int): maximum batch size - max_input_len (int): maximum input length of sequence - max_output_len (int): maximum output length of output tokens - dtype (torch.dtype): datatype used to init KV cache space - device (str): device the KV cache of engine to be initialized on - - Examples: - >>> # define model and shard config for your inference - >>> model = ... - >>> generate_kwargs = ... - >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) - >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - >>> outputs = infer_engine.generate(input_ids, **generate_kwargs) - """ - - def __init__( - self, - model: nn.Module, - shard_config: ShardConfig, - max_batch_size: int, - max_input_len: int, - max_output_len: int, - dtype: torch.dtype = torch.float16, - device: str = "cuda", - ) -> None: - self.max_batch_size = max_batch_size - self.max_input_len = max_input_len - self.max_output_len = max_output_len - self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) - # Constraints relatable with specs of devices and model - # This may change into an optional arg in the future - assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" - assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint" - - self.dtype = dtype - - self.head_dim = model.config.hidden_size // model.config.num_attention_heads - self.head_num = model.config.num_attention_heads - num_hidden_layers = ( - model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers - ) - self.layer_num = num_hidden_layers - - self.multi_query_group_num = model.config.num_attention_heads - # default to attention_heads - if hasattr(model.config, "multi_query_attention"): - self.multi_query_attention = getattr(model.config, "multi_query_attention") - - if hasattr(model.config, "multi_query_group_num"): - self.multi_query_group_num = getattr(model.config, "multi_query_group_num") - - if hasattr(model.config, "num_key_value_heads"): - self.multi_query_group_num = getattr(model.config, "num_key_value_heads") - - self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config - self.cache_manager = None - - self.max_dq_buffer_size = 1 - self.max_inner_outer_dim = 1 - self.gptq_temp_state_buffer = None - self.gptq_temp_dq_buffer = None - self.bits = -1 - self.use_act_order = False - - self.shard_config = shard_config - self.model = None - self.cache = {} - - # optimize the original model by sharding with ShardFormer - self._optimize_model(model=model.to(device)) - - def _init_manager(self) -> None: - assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" - assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" - self.head_num //= self.tp_size # update sharded number of heads - - if hasattr(self, "multi_query_attention"): - # NOTE the logic of MQA tensor parallelism should be specified. - assert ( - self.multi_query_group_num % self.tp_size == 0 - ), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}" - self.cache_manager = MemoryManager( - self.max_total_token_num, - self.dtype, - self.multi_query_group_num // self.tp_size, - self.head_dim, - self.layer_num, - ) - else: - self.cache_manager = MemoryManager( - self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num - ) - - def _post_init_gptq_buffer(self, model: nn.Module) -> None: - from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear - - HAS_GPTQ_CUDA = False - try: - from colossalai.kernel.op_builder.gptq import GPTQBuilder - - gptq_cuda = GPTQBuilder().load() - HAS_GPTQ_CUDA = True - except ImportError: - warnings.warn("CUDA gptq is not installed") - HAS_GPTQ_CUDA = False - - for name, submodule in model.named_modules(): - if isinstance(submodule, CaiQuantLinear): - self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) - - if self.use_act_order: - self.max_inner_outer_dim = max( - self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures - ) - self.bits = submodule.bits - if not (HAS_GPTQ_CUDA and self.bits == 4): - return - - max_input_len = 1 - if self.use_act_order: - max_input_len = self.max_input_len - # The temp_state buffer is required to reorder X in the act-order case. - # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - self.gptq_temp_state_buffer = torch.zeros( - (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() - ) - self.gptq_temp_dq_buffer = torch.zeros( - (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device() - ) - - gptq_cuda.prepare_buffers( - torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer - ) - # Using the default from exllama repo here. - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) - - torch.cuda.empty_cache() - - def _optimize_model(self, model: nn.Module) -> None: - """ - Optimize the original model by sharding with ShardFormer. - In further generation, use the sharded model instead of original model. - """ - # NOTE we will change to use an inference config later with additional attrs we want - assert self.shard_config.inference_only is True - shardformer = ShardFormer(shard_config=self.shard_config) - self._prepare_with_shard_config(shard_config=self.shard_config) - self._shard_model_by(shardformer, model) - - def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: - """Prepare the engine with a given ShardConfig. - - Args: - shard_config (ShardConfig): shard config given to specify settings of the engine. - If not provided, a default ShardConfig with tp size 1 will be created. - """ - self.tp_size = 1 - if shard_config is None: - shard_config = ShardConfig( - tensor_parallel_process_group=None, - pipeline_stage_manager=None, - enable_tensor_parallelism=False, - enable_fused_normalization=False, - enable_all_optimization=False, - enable_flash_attention=False, - enable_jit_fused=False, - inference_only=True, - ) - else: - shard_config.inference_only = True - shard_config.pipeline_stage_manager = None - if shard_config.enable_tensor_parallelism: - self.tp_size = shard_config.tensor_parallel_size - self._init_manager() - - return shard_config - - def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: - """Shard original model by the given ShardFormer and store the sharded model.""" - assert ( - self.tp_size == shardformer.shard_config.tensor_parallel_size - ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" - model_name = model.__class__.__name__ - assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." - - model = model.model if self.shard_config.inference_gptq else model - policy = get_autopolicy(model, shard_config=self.shard_config) - - self.model, _ = shardformer.optimize(model, policy) - - if self.shard_config.inference_gptq: - self._post_init_gptq_buffer(self.model) - - self.model = self.model.cuda() - - @property - def supported_models(self) -> List[str]: - return _supported_models - - def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor: - """Generate token sequence. - - Args: - input_tokens: could be one of the following types - 1. BatchEncoding or dict (e.g. tokenizer batch_encode) - 2. list of input token ids (e.g. appended result of tokenizer encode) - 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') - Returns: - torch.Tensor: The returned sequence is given inputs + generated_tokens. - """ - if isinstance(input_tokens, torch.Tensor): - input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool)) - for t in input_tokens: - if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].cuda() - if "max_new_tokens" not in generate_kwargs: - generate_kwargs.update(max_new_tokens=self.max_output_len) - - return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) - - def prepare_batch_state(self, inputs) -> BatchInferState: - """ - Create and prepare BatchInferState used for inference during model forwrad, - by processing each sequence of the given inputs. - - Args: - inputs: should be one of the following types - 1. BatchEncoding or dict (e.g. tokenizer batch_encode) - 2. list of input token ids (e.g. appended result of tokenizer encode) - 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') - NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve - the actual length (e.g. number of tokens) of each input without attention mask - Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume - all the inputs in the batch has the maximum length l - Returns: - BatchInferState: the states for the current batch during inference - """ - if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)): - raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") - - input_ids_list = None - attention_mask = None - - if isinstance(inputs, (BatchEncoding, dict)): - input_ids_list = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - else: - input_ids_list = inputs - if isinstance(input_ids_list[0], int): # for a single input - input_ids_list = [input_ids_list] - attention_mask = [attention_mask] if attention_mask is not None else attention_mask - - batch_size = len(input_ids_list) - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - start_index = 0 - - max_len_in_batch = -1 - if isinstance(inputs, (BatchEncoding, dict)): - for i, attn_mask in enumerate(attention_mask): - curr_seq_len = len(attn_mask) - # if isinstance(attn_mask, torch.Tensor): - # curr_seq_len = int(torch.sum(attn_mask)) - # else: - # curr_seq_len = int(sum(attn_mask)) - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - else: - length = max(len(input_id) for input_id in input_ids_list) - for i, input_ids in enumerate(input_ids_list): - curr_seq_len = length - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") - batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to("cuda") - batch_infer_state.start_loc = seq_start_indexes.to("cuda") - batch_infer_state.block_loc = block_loc - batch_infer_state.decode_layer_id = 0 - batch_infer_state.past_key_values_len = 0 - batch_infer_state.is_context_stage = True - batch_infer_state.set_cache_manager(self.cache_manager) - - return batch_infer_state - - @torch.no_grad() - def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor: - """ - Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate - - Args: - inputs: should be one of the following types - 1. BatchEncoding or dict (e.g. tokenizer batch_encode) - 2. list of input token ids (e.g. appended result of tokenizer encode) - 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') - """ - - # for testing, always use sharded model - assert self.model is not None, "sharded model does not exist" - - batch_infer_state = self.prepare_batch_state(input_tokens) - assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" - - # set BatchInferState for the current batch as attr to model - # NOTE this is not a preferable way to pass BatchInferState during inference - # we might want to rewrite generate function (e.g. _generate_by_pass_infer_state) - # and pass BatchInferState via model forward - model = self.model - if isinstance(model, LlamaForCausalLM): - model = self.model.model - elif isinstance(model, BloomForCausalLM): - model = self.model.transformer - setattr(model, "infer_state", batch_infer_state) - - outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) - - # NOTE In future development, we're going to let the scheduler to handle the cache, - # instead of freeing space explicitly at the end of generation - self.cache_manager.free_all() - - return outputs - - # TODO might want to implement the func that generates output tokens by passing BatchInferState - # as an arg into model.forward. - # It requires rewriting model generate and replacing model forward. - @torch.no_grad() - def _generate_by_pass_infer_state( - self, - input_tokens, - max_out_length: int, - generation_config: Optional[GenerationConfig] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - **model_kwargs, - ) -> torch.Tensor: - raise NotImplementedError("generate by passing BatchInferState is not implemented.") - - # might want to use in rewritten generate method: use after model.forward - # BatchInferState is created and kept during generation - # after each iter of model forward, we should update BatchInferState - def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: - batch_size = infer_state.batch_size - device = infer_state.start_loc.device - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) - infer_state.seq_len += 1 - - @torch.no_grad() - def forward(self, batch_id, is_prefill): - """ - Forward is used in Dynamic Batching Manager - """ - batch = self.cache.pop(batch_id) - if is_prefill: - input_ = torch.tensor(batch.all_input_ids).cuda() - else: - input_ = batch.input_ids.reshape(len(batch), 1) - - batch_args = { - "batch_size": len(batch), - "max_len_in_batch": batch.nopad_max_len_in_batch, - "block_loc": batch.nopad_b_loc, - "start_loc": batch.nopad_b_start_loc, - "seq_len": batch.nopad_b_seq_len, - "cache_manager": batch.cache_manager, - "is_context_stage": is_prefill, - } - - infer_state = BatchInferState(**batch_args) - model = self.model - if isinstance(model, LlamaForCausalLM): - model = self.model.model - elif isinstance(model, BloomForCausalLM): - model = self.model.transformer - - setattr(model, "infer_state", infer_state) - output = self.model.forward(input_ids=input_) - logits = output.logits - # bsz, seq_len, vocab_size - prob_out = torch.softmax( - logits[ - :, - -1, - ], - dim=-1, - ).squeeze(1) - # prob_out: bsz, vocab_size - predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) - prob_out = torch.log(prob_out).detach().cpu().numpy() - predict_ids = predict_ids.detach().cpu().numpy() - # [ batch_size, 1 ] - - output_dict = {} - new_input_ids = [] - for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( - zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) - ): - next_token_id = int(next_token_id) - next_token_logprob = next_token_logprob[next_token_id] - # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") - all_input_ids.append(next_token_id) - # all_input_ids_tensor = None - new_input_ids.append(next_token_id) - batch.all_input_ids[i] = all_input_ids - batch.input_lengths[i] += 1 - batch.out_token_id_counts[i][next_token_id] += 1 - metadata = { - "id": int(next_token_id), - "logprob": float(next_token_logprob), - } - output_dict[r["request_id"]] = (int(next_token_id), metadata) - - batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() - batch.nopad_total_token_num += len(batch) - batch.nopad_max_len_in_batch += 1 # NOTE: we may repalce this - self.cache[batch.batch_id] = batch - return output_dict - - @torch.no_grad() - def _prefill_batch(self, batch_id): - return self.forward(batch_id, is_prefill=True) - - @torch.no_grad() - def _decode_batch(self, batch_id): - return self.forward(batch_id, is_prefill=False) - - # might want to create a sequence pool - # add a single request/sequence/input text at a time and record its length - # In other words, store the actual length of input tokens representing a single input text - # E.g. "Introduce landmarks in Beijing" - # => add request - # => record token length and other necessary information to be used - # => engine hold all these necessary information until `generate` (or other name) is called, - # => put information already recorded in batchinferstate and pass it to model forward - # => clear records in engine - def add_request(): - raise NotImplementedError() diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py deleted file mode 100644 index 4662368b17b4..000000000000 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .bloom import BloomInferenceForwards -from .chatglm2 import ChatGLM2InferenceForwards -from .llama import LlamaInferenceForwards - -__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"] diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py deleted file mode 100644 index 068b64b4f829..000000000000 --- a/colossalai/inference/tensor_parallel/modeling/_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Utils for model inference -""" -import os - -import torch - -from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest - - -def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - """ - This function copies the key and value cache to the memory cache - Args: - layer_id : id of current layer - key_buffer : key cache - value_buffer : value cache - context_mem_index : index of memory cache in kv cache manager - mem_manager : cache manager - """ - copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - - -def init_to_get_rotary(self, base=10000, use_elem=False): - """ - This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer - Args: - self : Model that holds the rotary positional embedding - base : calculation arg - use_elem : activated when using chatglm-based models - """ - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - - if hasattr(self.config, "max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config, "max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - - # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) - - if ntk_alpha is not None: - ntk_alpha = float(ntk_alpha) - assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" - if ntk_alpha > 1: - print(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula - - n_elem = self.config.head_dim_ - if use_elem: - n_elem //= 2 - - inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py deleted file mode 100644 index 0ad3994b0194..000000000000 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ /dev/null @@ -1,537 +0,0 @@ -import math -import warnings -from typing import Optional, Tuple, Union - -import torch -import torch.distributed as dist -from torch.nn import CrossEntropyLoss -from torch.nn import functional as F -from transformers.models.bloom.modeling_bloom import ( - BaseModelOutputWithPastAndCrossAttentions, - BloomAttention, - BloomBlock, - BloomForCausalLM, - BloomModel, - CausalLMOutputWithCrossAttentions, -) -from transformers.utils import logging - -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd - -try: - from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_bloom_context_attention_fwd - HAS_LIGHTLLM_KERNEL = True -except: - HAS_LIGHTLLM_KERNEL = False - - -def generate_alibi(n_head, dtype=torch.float16): - """ - This method is adapted from `_generate_alibi` function - in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py` - of the ModelTC/lightllm GitHub repository. - This method is originally the `build_alibi_tensor` function - in `transformers/models/bloom/modeling_bloom.py` - of the huggingface/transformers GitHub repository. - """ - - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - return [start * start**i for i in range(n)] - - def get_slopes(n): - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) - slopes_double = get_slopes(2 * closest_power_of_2) - slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2] - return slopes_combined - - slopes = get_slopes(n_head) - return torch.tensor(slopes, dtype=dtype) - - -class BloomInferenceForwards: - """ - This class serves a micro library for bloom inference forwards. - We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention, - as well as prepare_inputs_for_generation method for BloomForCausalLM. - For future improvement, we might want to skip replacing methods for BloomForCausalLM, - and call BloomModel.forward iteratively in TpInferEngine - """ - - @staticmethod - def bloom_model_forward( - self: BloomModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: Optional[BatchInferState] = None, - **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - logger = logging.get_logger(__name__) - - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - # still need to keep past_key_values to fit original forward flow - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # NOTE determine if BatchInferState is passed in via arg - # if not, get the attr binded to the model - # We might wantto remove setattr later - if infer_state is None: - assert hasattr(self, "infer_state") - infer_state = self.infer_state - - # infer_state.cache_manager = self.cache_manager - if infer_state.is_context_stage: - past_key_values_length = 0 - else: - past_key_values_length = infer_state.max_len_in_batch - 1 - - if use_cache and seq_length != 1: - # prefill stage - infer_state.is_context_stage = True # set prefill stage, notify attention layer - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - BatchInferState.init_block_loc( - infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index - ) - else: - infer_state.is_context_stage = False - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - - if attention_mask is None: - attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - # NOTE revise: we might want to store a single 1D alibi(length is #heads) in model, - # or store to BatchInferState to prevent re-calculating - # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here - # alibi = generate_alibi(self.num_heads).contiguous().cuda() - tp_size = dist.get_world_size() - curr_tp_rank = dist.get_rank() - alibi = ( - generate_alibi(self.num_heads * tp_size) - .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads] - .cuda() - ) - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - infer_state.decode_layer_id = 0 - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - # NOTE: currently our KV cache manager does not handle this condition - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - alibi, - causal_mask, - layer_past, - head_mask[i], - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - infer_state=infer_state, - ) - - infer_state.decode_layer_id += 1 - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # update indices of kv cache block - # NOT READY FOR PRIME TIME - # might want to remove this part, instead, better to pass the BatchInferState from model forward, - # and update these information in engine.generate after model foward called - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.max_len_in_batch += 1 - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, # should always be (None, None, ..., None) - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - @staticmethod - def bloom_for_causal_lm_forward( - self: BloomForCausalLM, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: Optional[BatchInferState] = None, - **deprecated_arguments, - ): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - logging.get_logger(__name__) - - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = BloomInferenceForwards.bloom_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def bloom_for_causal_lm_prepare_inputs_for_generation( - self: BloomForCausalLM, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - - # NOTE we won't use past key values here - # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed - # if past_key_values[0][0].shape[0] == input_ids.shape[0]: - # past_key_values = self._convert_to_bloom_cache(past_key_values) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def bloom_block_forward( - self: BloomBlock, - hidden_states: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - infer_state: Optional[BatchInferState] = None, - ): - # hidden_states: [batch_size, seq_length, hidden_size] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - - # Layer norm post the self attention. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - # Self attention. - attn_outputs = self.self_attention( - layernorm_output, - residual, - layer_past=layer_past, - attention_mask=attention_mask, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - infer_state=infer_state, - ) - - attention_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - layernorm_output = self.post_attention_layernorm(attention_output) - - # Get residual - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = attention_output - - # MLP. - output = self.mlp(layernorm_output, residual) - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - @staticmethod - def bloom_attention_forward( - self: BloomAttention, - hidden_states: torch.Tensor, - residual: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - infer_state: Optional[BatchInferState] = None, - ): - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, q_length, H, D_HEAD = query_layer.shape - k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - - mem_manager = infer_state.cache_manager - layer_id = infer_state.decode_layer_id - - if infer_state.is_context_stage: - # context process - max_input_len = q_length - b_start_loc = infer_state.start_loc - b_seq_len = infer_state.seq_len[:batch_size] - q = query_layer.reshape(-1, H, D_HEAD) - - copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) - - # output = self.output[:batch_size*q_length, :, :] - output = torch.empty_like(q) - - if HAS_LIGHTLLM_KERNEL: - lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) - else: - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) - - context_layer = output.view(batch_size, q_length, H * D_HEAD) - else: - # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD) - assert q_length == 1, "for non-context process, we only support q_length == 1" - q = query_layer.reshape(-1, H, D_HEAD) - - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_k.copy_(k) - cache_v.copy_(v) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] - copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) - - b_start_loc = infer_state.start_loc - b_loc = infer_state.block_loc - b_seq_len = infer_state.seq_len - output = torch.empty_like(q) - token_attention_fwd( - q, - mem_manager.key_buffer[layer_id], - mem_manager.value_buffer[layer_id], - output, - b_loc, - b_start_loc, - b_seq_len, - infer_state.max_len_in_batch, - alibi, - ) - - context_layer = output.view(batch_size, q_length, H * D_HEAD) - - # NOTE: always set present as none for now, instead of returning past key value to the next decoding, - # we create the past key value pair from the cache manager - present = None - - # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 - if self.pretraining_tp > 1 and self.slow_but_exact: - slices = self.hidden_size / self.pretraining_tp - output_tensor = torch.zeros_like(context_layer) - for i in range(self.pretraining_tp): - output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], - ) - else: - output_tensor = self.dense(context_layer) - - # dropout is not required here during inference - output_tensor = residual + output_tensor - - outputs = (output_tensor, present) - assert output_attentions is False, "we do not support output_attentions at this time" - - return outputs diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py deleted file mode 100644 index b8fe8eb54855..000000000000 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ /dev/null @@ -1,545 +0,0 @@ -import os -from typing import Optional, Tuple - -import torch -from torch.nn import CrossEntropyLoss -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast - -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, - GLMTransformer, - SelfAttention, - split_tensor_along_last_dim, -) - -from ._utils import copy_kv_to_mem_cache - -try: - from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_llama2_context_attention_fwd, - ) - - HAS_LIGHTLLM_KERNEL = True -except: - print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") - HAS_LIGHTLLM_KERNEL = False - - -# This func is same as Llama model init_to_get_rotary, we should move them into _utils.py -def _init_to_get_rotary(self, base=10000): - self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads - if not hasattr(self.config, "rope_scaling"): - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - if hasattr(self.config, "max_sequence_length"): - max_seq_len = self.config.max_sequence_length - elif hasattr(self.config, "max_position_embeddings"): - max_seq_len = self.config.max_position_embeddings * rope_scaling_factor - else: - max_seq_len = 2048 * rope_scaling_factor - base = float(base) - - # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - try: - ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - print(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula - except: - pass - n_elem = self.config.head_dim_ // 2 - inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() - return - - -def get_masks(self, input_ids, past_length, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - if past_length: - full_attention_mask = torch.cat( - ( - torch.ones(batch_size, seq_length, past_length, device=input_ids.device), - full_attention_mask, - ), - dim=-1, - ) - - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - -class ChatGLM2InferenceForwards: - """ - This class holds forwards for Chatglm2 inference. - We intend to replace the forward methods for ChatGLMModel, ChatGLMEecoderLayer, and ChatGLMAttention. - """ - - @staticmethod - def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - infer_state = self.infer_state - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if infer_state.is_context_stage: - past_key_values_length = 0 - else: - past_key_values_length = infer_state.max_len_in_batch - 1 - - seq_length_with_past = seq_length + past_key_values_length - - # prefill stage at first - if use_cache and seq_length != 1: - infer_state.is_context_stage = True - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc( - infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index - ) - else: - infer_state.is_context_stage = False - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - - # related to rotary embedding - if infer_state.is_context_stage: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - else: - seq_len = infer_state.seq_len - infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def chatglm_model_forward( - self: ChatGLMModel, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: BatchInferState = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt( - batch_size=batch_size, - device=input_ids.device, - dtype=inputs_embeds.dtype, - ) - if attention_mask is not None: - attention_mask = torch.cat( - [ - attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask, - ], - dim=-1, - ) - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = get_masks( - self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask - ) - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, - full_attention_mask, - kv_caches=past_key_values, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - infer_state=infer_state, - ) - - # update indices - # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.max_len_in_batch += 1 - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - @staticmethod - def chatglm_encoder_forward( - self: GLMTransformer, - hidden_states, - attention_mask, - kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - infer_state: Optional[BatchInferState] = None, - ): - hidden_states = hidden_states.transpose(0, 1).contiguous() - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - - infer_state.decode_layer_id = 0 - for index in range(self.num_layers): - layer = self.layers[index] - - layer_ret = layer( - hidden_states, - attention_mask, - kv_cache=kv_caches[index], - use_cache=use_cache, - infer_state=infer_state, - ) - - infer_state.decode_layer_id += 1 - - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - hidden_states = hidden_states.transpose(0, 1).contiguous() - - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - @staticmethod - def chatglm_glmblock_forward( - self: GLMBlock, - hidden_states, - attention_mask, - kv_cache=None, - use_cache=True, - infer_state: Optional[BatchInferState] = None, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - kv_cache=kv_cache, - use_cache=use_cache, - infer_state=infer_state, - ) - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - return output, kv_cache - - @staticmethod - def chatglm_flash_attn_kvcache_forward( - self: SelfAttention, - hidden_states, - attention_mask, - kv_cache=None, - use_cache=True, - infer_state: Optional[BatchInferState] = None, - ): - assert use_cache is True, "use_cache should be set to True using this chatglm attention" - # hidden_states: original :[sq, b, h] --> this [b, sq, h] - batch_size = hidden_states.shape[0] - hidden_size = hidden_states.shape[-1] - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] - + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] - + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - ) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - ) - ) - - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - cos, sin = infer_state.position_cos, infer_state.position_sin - - chatglm2_rotary_emb_fwd( - query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin - ) - if self.multi_query_attention: - chatglm2_rotary_emb_fwd( - key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), - cos, - sin, - ) - else: - chatglm2_rotary_emb_fwd( - key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), - cos, - sin, - ) - - # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128 - query_layer = query_layer.reshape( - -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head - ) - key_layer = key_layer.reshape( - -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head - ) - value_layer = value_layer.reshape( - -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head - ) - - if infer_state.is_context_stage: - # first token generation: - # copy key and value calculated in current step to memory manager - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_layer, - value_layer, - infer_state.context_mem_index, - infer_state.cache_manager, - ) - attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) - - # NOTE: no bug in context attn fwd (del it ) - lightllm_llama2_context_attention_fwd( - query_layer, - key_layer, - value_layer, - attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - - else: - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_k.copy_(key_layer) - cache_v.copy_(value_layer) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_layer, - value_layer, - infer_state.decode_mem_index, - infer_state.cache_manager, - ) - - # second token and follows - attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - : infer_state.decode_mem_end, :, : - ] - - # ================================== - # core attention computation is replaced by triton kernel - # ================================== - Llama2TokenAttentionForwards.token_attn( - query_layer, - cache_k, - cache_v, - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - infer_state.other_kv_index, - ) - - # print('after attention',torch.isnan(attn_output).any()) - - # ================= - # Output:[b,sq, h] - # ================= - output = self.dense(attn_output).reshape(batch_size, -1, hidden_size) - - return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py deleted file mode 100644 index 62c2aad3c055..000000000000 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ /dev/null @@ -1,423 +0,0 @@ -import math -from typing import List, Optional, Tuple - -import torch -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel - -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd -from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards - -from ._utils import copy_kv_to_mem_cache - -try: - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_llama2_context_attention_fwd, - ) - from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd as lightllm_context_attention_fwd, - ) - from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd - - HAS_LIGHTLLM_KERNEL = True -except: - print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") - HAS_LIGHTLLM_KERNEL = False - -try: - from flash_attn import flash_attn_with_kvcache - - HAS_FLASH_KERNEL = True -except: - HAS_FLASH_KERNEL = False - print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def llama_triton_context_attention( - query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 -): - if num_key_value_groups == 1: - if HAS_LIGHTLLM_KERNEL is False: - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) - else: - lightllm_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) - else: - assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" - lightllm_llama2_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) - - -def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): - assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" - if num_key_value_groups == 1: - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - ) - else: - Llama2TokenAttentionForwards.token_attn( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - infer_state.max_len_in_batch, - infer_state.other_kv_index, - ) - - -class LlamaInferenceForwards: - """ - This class holds forwards for llama inference. - We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. - """ - - @staticmethod - def llama_model_forward( - self: LlamaModel, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - infer_state = self.infer_state - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - use_cache = use_cache if use_cache is not None else self.config.use_cache - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if infer_state.is_context_stage: - past_key_values_length = 0 - else: - past_key_values_length = infer_state.max_len_in_batch - 1 - - # NOTE: differentiate with prefill stage - # block_loc require different value-assigning method for two different stage - if use_cache and seq_length != 1: - # NOTE assume prefill stage - # allocate memory block - infer_state.is_context_stage = True # set prefill stage, notify attention layer - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc( - infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index - ) - else: - infer_state.is_context_stage = False - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.repeat(batch_size, 1) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if infer_state.is_context_stage: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1 - ) - - else: - seq_len = infer_state.seq_len - infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device - ) - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - infer_state.decode_layer_id = 0 - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] if past_key_values is not None else None - # NOTE: modify here for passing args to decoder layer - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - infer_state=infer_state, - ) - infer_state.decode_layer_id += 1 - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - hidden_states = self.norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None - - # update indices - # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.max_len_in_batch += 1 - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - @staticmethod - def llama_decoder_layer_forward( - self: LlamaDecoderLayer, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - infer_state: Optional[BatchInferState] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - infer_state=infer_state, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - @staticmethod - def llama_flash_attn_kvcache_forward( - self: LlamaAttention, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - infer_state: Optional[BatchInferState] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - assert use_cache is True, "use_cache should be set to True using this llama attention" - - bsz, q_len, _ = hidden_states.size() - - # NOTE might think about better way to handle transposed k and v - # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] - # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] - - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - # NOTE might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - - cos, sin = infer_state.position_cos, infer_state.position_sin - - llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) - - query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) - value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim) - - if infer_state.is_context_stage: - # first token generation - # copy key and value calculated in current step to memory manager - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_states, - value_states, - infer_state.context_mem_index, - infer_state.cache_manager, - ) - attn_output = torch.empty_like(query_states) - - llama_triton_context_attention( - query_states, - key_states, - value_states, - attn_output, - infer_state, - num_key_value_groups=self.num_key_value_groups, - ) - else: - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start : infer_state.decode_mem_end, :, : - ] - cache_k.copy_(key_states) - cache_v.copy_(value_states) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - copy_kv_to_mem_cache( - infer_state.decode_layer_id, - key_states, - value_states, - infer_state.decode_mem_index, - infer_state.cache_manager, - ) - - if HAS_LIGHTLLM_KERNEL: - attn_output = torch.empty_like(query_states) - llama_triton_token_attention( - query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups - ) - else: - self.num_heads // self.num_key_value_heads - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - - query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) - copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) - copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) - - attn_output = flash_attn_with_kvcache( - q=query_states, - k_cache=copy_cache_k, - v_cache=copy_cache_v, - softmax_scale=1 / math.sqrt(self.head_dim), - causal=True, - ) - - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - # return past_key_value as None - return attn_output, None, None diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py deleted file mode 100644 index 776c4e850565..000000000000 --- a/colossalai/inference/tensor_parallel/policies/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .bloom import BloomModelInferPolicy -from .chatglm2 import ChatGLM2InferPolicy -from .llama import LlamaModelInferPolicy - -__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"] diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py deleted file mode 100644 index 3d6df2097000..000000000000 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ /dev/null @@ -1,99 +0,0 @@ -from functools import partial - -import torch -from torch.nn import LayerNorm - -import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription -from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy - -from ..modeling.bloom import BloomInferenceForwards - -try: - from colossalai.kernel.triton import layer_norm - - HAS_TRITON_NORM = True -except: - print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton") - HAS_TRITON_NORM = False - - -def get_triton_layernorm_forward(): - if HAS_TRITON_NORM: - - def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor): - return layer_norm(hidden_states, self.weight.data, self.bias, self.eps) - - return _triton_layernorm_forward - else: - return None - - -class BloomModelInferPolicy(BloomForCausalLMPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel - - policy = super().module_policy() - if self.shard_config.inference_gptq: - from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 3}), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - ]) - # NOTE set inference mode to shard config - self.shard_config._infer() - - method_replacement = { - "forward": BloomInferenceForwards.bloom_for_causal_lm_forward, - "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation, - } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=BloomForCausalLM - ) - - method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel) - - method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock) - - method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=BloomAttention - ) - - if HAS_TRITON_NORM: - infer_method = get_triton_layernorm_forward() - method_replacement = {"forward": partial(infer_method)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LayerNorm - ) - - return policy diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py deleted file mode 100644 index 60dc511f5e96..000000000000 --- a/colossalai/inference/tensor_parallel/policies/chatglm2.py +++ /dev/null @@ -1,77 +0,0 @@ -from functools import partial - -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, - GLMTransformer, - SelfAttention, -) - -# import colossalai -from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy - -from ..modeling._utils import init_to_get_rotary -from ..modeling.chatglm2 import ChatGLM2InferenceForwards - -try: - HAS_TRITON_RMSNORM = True -except: - print("you should install triton from https://github.com/openai/triton") - HAS_TRITON_RMSNORM = False - - -class ChatGLM2InferPolicy(ChatGLMModelPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - policy = super().module_policy() - self.shard_config._infer() - - model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward - method_replacement = {"forward": model_infer_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel) - - encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward - method_replacement = {"forward": encoder_infer_forward} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=GLMTransformer - ) - - encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward - method_replacement = {"forward": encoder_layer_infer_forward} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock) - - attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward - method_replacement = {"forward": attn_infer_forward} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=SelfAttention - ) - if self.shard_config.enable_tensor_parallelism: - policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = ( - self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size - ) - # for rmsnorm and others, we need to check the shape - return policy - - def postprocess(self): - init_to_get_rotary(self.model) - return self.model - - -class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - policy = super().module_policy() - model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward - method_replacement = {"forward": partial(model_infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration - ) - return policy - - def postprocess(self): - return super().postprocess() diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py deleted file mode 100644 index d6c072c747b7..000000000000 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ /dev/null @@ -1,119 +0,0 @@ -from functools import partial - -import torch -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm - -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription - -# import colossalai -from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy - -from ..modeling._utils import init_to_get_rotary -from ..modeling.llama import LlamaInferenceForwards - -try: - from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward - HAS_TRITON_RMSNORM = True -except: - print("you should install triton from https://github.com/openai/triton") - HAS_TRITON_RMSNORM = False - - -def get_triton_rmsnorm_forward(): - if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) - - return _triton_rmsnorm_forward - else: - return None - - -class LlamaModelInferPolicy(LlamaForCausalLMPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - policy = super().module_policy() - - if self.shard_config.inference_gptq: - from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - - decoder_attribute_replacement = { - "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - } - policy[LlamaDecoderLayer] = ModulePolicyDescription( - attribute_replacement=decoder_attribute_replacement, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=RowCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 1}, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=RowCaiQuantLinear, - kwargs={"split_num": 1}, - ), - ], - ) - - self.shard_config._infer() - - infer_forward = LlamaInferenceForwards.llama_model_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) - - infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaDecoderLayer - ) - - infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaAttention - ) - - infer_forward = None - if HAS_TRITON_RMSNORM: - infer_forward = get_triton_rmsnorm_forward() - - if infer_forward is not None: - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaRMSNorm - ) - - return policy - - def postprocess(self): - init_to_get_rotary(self.model.model) - return self.model diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 3ec0f97a747a..0b0bb11f4661 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -7,7 +7,7 @@ from torch.nn import Module from torch.utils._pytree import tree_map -from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status +from colossalai.inference.hybridengine.microbatch_manager import MicroBatchManager, Status from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.cuda import get_current_device diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/examples/inference/benchmark.py similarity index 100% rename from colossalai/inference/pipeline/benchmark/benchmark.py rename to examples/inference/benchmark.py diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/examples/inference/run.sh similarity index 100% rename from colossalai/inference/pipeline/benchmark/run.sh rename to examples/inference/run.sh diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py deleted file mode 100644 index d4366758d6a3..000000000000 --- a/tests/test_infer/test_bloom_infer.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest -import torch -from packaging import version -from transformers import BloomForCausalLM -from transformers.models.bloom.configuration_bloom import BloomConfig - -import colossalai -from colossalai.inference.tensor_parallel import TPInferEngine -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -try: - import lightllm - HAS_LIGHTLLM_KERNEL = True -except: - HAS_LIGHTLLM_KERNEL = False - -TP_SIZE = 2 -MAX_BATCH_SIZE = 4 -MAX_INPUT_LEN = 16 -MAX_OUTPUT_LEN = 32 - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - - -@parameterize( - "test_config", - [ - { - "tp_size": TP_SIZE, - } - ], -) -def run(test_config): - bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) - model = BloomForCausalLM(bloom_config) - model = model.half() - - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - - input_tokens = { - "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - } - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - - assert outputs is not None - - -def check_bloom(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run() - - -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_bloom_infer(): - spawn(check_bloom, TP_SIZE) - - -if __name__ == "__main__": - test_bloom_infer() diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py deleted file mode 100644 index a2ec35dcdb8a..000000000000 --- a/tests/test_infer/test_chatglm2_infer.py +++ /dev/null @@ -1,83 +0,0 @@ -import os - -import pytest -import torch -from packaging import version - -import colossalai -from colossalai.inference.tensor_parallel.engine import TPInferEngine -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -try: - import lightllm # noqa - - HAS_LIGHTLLM_KERNEL = True -except: - HAS_LIGHTLLM_KERNEL = False - -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -TPSIZE = 2 -BATCH_SIZE = 8 -MAX_INPUT_LEN = 12 -MAX_OUTPUT_LEN = 100 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - - -@parameterize( - "test_config", - [ - { - "tp_size": TPSIZE, - } - ], -) -def run_chatglm2_test(test_config): - chatglm_config = ChatGLMConfig( - num_layers=2, - vocab_size=1200, - use_cache=True, - multi_query_attention=True, - multi_query_group_num=2, - num_attention_heads=8, - hidden_size=1024, - ) - model = ChatGLMForConditionalGeneration(chatglm_config) - model = model.half() - - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - - input_tokens = { - "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - } - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - assert outputs is not None - - -def check_chatglm2(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_chatglm2_test() - - -@pytest.mark.skipif( - not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, - reason="kv-cache manager engine requires cuda version to be higher than 11.5", -) -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_chatglm2(): - spawn(check_chatglm2, TPSIZE) - - -if __name__ == "__main__": - test_chatglm2() diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml deleted file mode 100644 index 0ac778a3c7b3..000000000000 --- a/tests/test_infer/test_dynamic_batching/config.yaml +++ /dev/null @@ -1,14 +0,0 @@ -engine_config: - model: MODEL_PATH - tensor_parallel_size: 1 - max_batch_size: 2 - max_input_len: 1024 - max_output_len: 512 -# config for app router deployment -# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig. -router_config: - max_total_token_num: 4096 - batch_max_tokens: 4096 - disable_log_stats: False - log_stats_interval: 10 - model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py deleted file mode 100644 index 512aa7430983..000000000000 --- a/tests/test_infer/test_dynamic_batching/test_async_engine.py +++ /dev/null @@ -1,61 +0,0 @@ -import asyncio -import os -import uuid - -import pytest - -import colossalai -from colossalai.inference.async_engine import Async_Engine -from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig -from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn - -PATH = "config.yaml" - - -def run_async_engine(path: str): - if not os.path.exists(path): - return - - config = RayInitConfig.from_yaml_path(path) - engine_config = config.engine_config_data - model = engine_config.model - if model is None or not os.path.exists(model): - return - - prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10" - sampling_params = SamplingParams() - asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) - - -async def get_result(engine, prompt, sampling_params): - request_id = str(uuid.uuid4().hex) - results = engine.generate(request_id, prompt, sampling_params) - async for result in results: - # print(result) - assert result is not None - - -async def asy_for_loop_test(config, prompt, sampling_params): - router_config = config.router_config_data - engine_config = config.engine_config_data - engine = Async_Engine(router_config=router_config, engine_config=engine_config) - for i in range(10): - print("in for loop", i) - await get_result(engine, prompt, sampling_params) - - -def check_async_engine(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_async_engine(PATH) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_async_engine(): - spawn(check_async_engine, 1) - - -if __name__ == "__main__": - test_async_engine() diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py deleted file mode 100644 index 78df0d304096..000000000000 --- a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py +++ /dev/null @@ -1,95 +0,0 @@ -import pytest -from transformers import LlamaForCausalLM -from transformers.models.llama.configuration_llama import LlamaConfig - -import colossalai -from colossalai.inference.dynamic_batching.io_struct import Req -from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import DynamicBatchManager -from colossalai.inference.tensor_parallel import TPInferEngine -from colossalai.shardformer import ShardConfig -from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn - -TP_SIZE = 1 -BATCH_SIZE = 2 -MAX_INPUT_LEN = 48 -MAX_OUTPUT_LEN = 256 - - -def run(): - sampling_params = SamplingParams() - - req1 = Req(0, [1], sampling_params) - req2 = Req(1, [2], sampling_params) - req3 = Req(2, [3], sampling_params) - # req 1-3 are initiliazed as token forward requests - req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) - waiting_list = [] - waiting_list.append(req1) - waiting_list.append(req2) - waiting_list.append(req3) - - # init model and tp engine - llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) - model = LlamaForCausalLM(llama_config) - model = model.half() - - shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) - infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - dynamic_batch_manager = DynamicBatchManager( - tp_engine=infer_engine, - max_total_token_num=640, - batch_max_tokens=608, - eos_id=0, - log_stats=False, - log_stats_interval=10, - waiting_req_list=waiting_list, - model="llama", - ) - before_add = len(dynamic_batch_manager.req_queue) - - # test add req function - dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params) - assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1 - - # test abort function - dynamic_batch_manager.abort(req4.request_id) - assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True - - # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested - batch = dynamic_batch_manager.req_queue.generate_new_batch() - assert len(batch) == 2 - - dynamic_batch_manager._init_batch(batch) - assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None - - batch.reqs[0].has_generate_finished = True - # filter one finished - batch.filter_finished() - dynamic_batch_manager._filter_batch(batch) - assert len(dynamic_batch_manager.engine.cache) == 1 - - # test merge batch - new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch) - assert len(new_batch) == 1 - dynamic_batch_manager._init_batch(new_batch) - dynamic_batch_manager._merge_batch(batch, new_batch) - - assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2 - - -def check_dynamic_batching_manager(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_dynamic_batching_manager(): - spawn(check_dynamic_batching_manager, 1) - - -if __name__ == "__main__": - test_dynamic_batching_manager() diff --git a/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py deleted file mode 100644 index 9925a80b6e77..000000000000 --- a/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py +++ /dev/null @@ -1,84 +0,0 @@ -from dataclasses import dataclass - -import pytest -import torch -from packaging import version -from transformers import LlamaForCausalLM -from transformers.models.llama.configuration_llama import LlamaConfig - -import colossalai -from colossalai.inference.dynamic_batching.io_struct import Req -from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import start_dynamic_batching -from colossalai.inference.tensor_parallel import TPInferEngine -from colossalai.shardformer import ShardConfig -from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn - -TP_SIZE = 1 -MAX_BATCH_SIZE = 2 -MAX_INPUT_LEN = 5 -MAX_OUTPUT_LEN = 16 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - - -@dataclass -class args: - max_total_token_num: int - batch_max_tokens: int - model: str - eos_id: int - disable_log_stats: bool - log_stats_interval: int - - -def run(): - arg = args( - max_total_token_num=42, - model="llama", - batch_max_tokens=42, - eos_id=0, - disable_log_stats=False, - log_stats_interval=10, - ) - sampling_params = SamplingParams() - - req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) - req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) - req3 = Req(2, [0, 0, 10, 10, 10], sampling_params) - req4 = Req(3, [0, 0, 10, 10, 10], sampling_params) - - waiting_list = [] - waiting_list.append(req1) - waiting_list.append(req2) - waiting_list.append(req3) - waiting_list.append(req4) - - llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024) - model = LlamaForCausalLM(llama_config) - model = model.half() - - shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) - - infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) - - ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params) - for result in ans_gen: - assert result is not None - - -def check_dynamic_forward(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run() - - -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_dynamic_batching(): - spawn(check_dynamic_forward, TP_SIZE) - - -if __name__ == "__main__": - test_dynamic_batching() diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py deleted file mode 100644 index a840407d5867..000000000000 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ /dev/null @@ -1,66 +0,0 @@ -import asyncio -import os -import uuid - -import pytest - -import colossalai -from colossalai.inference.dynamic_batching.ray_dist_init import Driver -from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig -from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn - -PATH = "config.yaml" - - -def run_ray_dist(path: str): - if not os.path.exists(path): - return - config = RayInitConfig.from_yaml_path(path) - router_config = config.router_config_data - engine_config = config.engine_config_data - model = engine_config.model - if model is None or not os.path.exists(model): - return - driver = Driver(router_config=router_config, engine_config=engine_config) - prompt = "Introduce some landmarks in Beijing" - - request_id = str(uuid.uuid4().hex) - sampling_params = SamplingParams() - print("sampling_params: ", sampling_params) - - async def get_result(request_id, prompt, sampling_params): - return await driver.async_generate(request_id, prompt, sampling_params) - - for test_async in [True, False]: - if test_async: - print("test_async: ", test_async) - result = asyncio.run(get_result(request_id, prompt, sampling_params)) - assert result is not None - print("result: ", result) - else: - print("test_async: ", test_async) - result = driver.generate(request_id, prompt, sampling_params) - assert result is not None - print("result: ", result) - - is_running = None - is_running = driver.is_running() - assert is_running is not None - print("is_running: ", is_running) - - -def check_ray_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_ray_dist(PATH) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_ray_dist(): - spawn(check_ray_dist, 1) - - -if __name__ == "__main__": - test_ray_dist() diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py deleted file mode 100644 index f24160820e71..000000000000 --- a/tests/test_infer/test_infer_engine.py +++ /dev/null @@ -1,102 +0,0 @@ -from itertools import accumulate - -import pytest -import torch -from packaging import version -from transformers import BloomConfig, BloomForCausalLM -from transformers.tokenization_utils_base import BatchEncoding - -import colossalai -from colossalai.inference.tensor_parallel import TPInferEngine -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -TP_SIZE = 2 -MAX_BATCH_SIZE = 4 -MAX_INPUT_LEN = 16 -MAX_OUTPUT_LEN = 8 - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - - -@parameterize( - "test_config", - [ - { - "tp_size": TP_SIZE, - } - ], -) -def run(test_config): - model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) - model = BloomForCausalLM(model_config) - model = model.half() - model.to(torch.cuda.current_device()) - - # 1. check TPInferEngine init and model optimization - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - assert infer_engine.cache_manager is not None - assert infer_engine.tp_size == TP_SIZE - assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE - - # 2. check data preparation - input_ids_list = [ - [80540, 15473, 3331, 11970, 90472, 361, 61335], - [80540, 15473, 3331, 11970], - [80540, 15473, 3331, 11970], - [80540, 15473], - ] - batch_size = len(input_ids_list) - max_seq_len = max(len(li) for li in input_ids_list) - attention_mask = [[0] * max_seq_len for _ in range(batch_size)] - for i, li in enumerate(input_ids_list): - attention_mask[i][max_seq_len - len(li) :] = [1 for _ in range(len(li))] - data = dict(input_ids=input_ids_list, attention_mask=attention_mask) - inputs_batch_encoding = BatchEncoding(data=data) - seq_lengths = [len(li) for li in input_ids_list] - start_loc = list(accumulate([0] + seq_lengths[:-1])) - seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32) - start_loc = torch.tensor(start_loc, dtype=torch.int32) - # input token id list as inputs - batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding) - # BatchEncoding as inputs - batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list) - - assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size - assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len) - - # The following tests are discarded for now, and will be reused after all features are added - # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths) - # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths) - # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc) - # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc) - - # 3. check optimized model generate - input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) - generate_kwargs = dict(do_sample=False) - infer_engine.generate(input_ids, **generate_kwargs) - - torch.cuda.empty_cache() - - -def check_engine(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run() - - -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_engine(): - spawn(check_engine, TP_SIZE) - - -if __name__ == "__main__": - test_engine() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index f3e2cdf1e18f..9b4ed937ebd7 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -4,7 +4,7 @@ import torch from packaging import version -from colossalai.inference.tensor_parallel import MemoryManager +from colossalai.inference.kvcache_manager import MemoryManager from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py deleted file mode 100644 index 13e7a61826ab..000000000000 --- a/tests/test_infer/test_llama2_infer.py +++ /dev/null @@ -1,75 +0,0 @@ -import os - -import pytest -import torch -from packaging import version -from transformers import LlamaForCausalLM -from transformers.models.llama.configuration_llama import LlamaConfig - -import colossalai -from colossalai.inference.tensor_parallel.engine import TPInferEngine -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -try: - import lightllm - HAS_LIGHTLLM_KERNEL = True -except: - HAS_LIGHTLLM_KERNEL = False - -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -TPSIZE = 2 -BATCH_SIZE = 8 -MAX_INPUT_LEN = 12 -MAX_OUTPUT_LEN = 100 - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - - -@parameterize( - "test_config", - [ - { - "tp_size": TPSIZE, - } - ], -) -def run_llama_test(test_config): - llama_config = LlamaConfig( - num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024 - ) - model = LlamaForCausalLM(llama_config) - model = model.half() - - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - - input_tokens = { - "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - } - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - - assert outputs is not None - - -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_llama_test() - - -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama(): - spawn(check_llama, TPSIZE) - - -if __name__ == "__main__": - test_llama() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py deleted file mode 100644 index a4f54d197065..000000000000 --- a/tests/test_infer/test_llama_infer.py +++ /dev/null @@ -1,73 +0,0 @@ -import os - -import pytest -import torch -from packaging import version -from transformers import LlamaForCausalLM -from transformers.models.llama.configuration_llama import LlamaConfig - -import colossalai -from colossalai.inference.tensor_parallel.engine import TPInferEngine -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn - -try: - import lightllm - HAS_LIGHTLLM_KERNEL = True -except: - HAS_LIGHTLLM_KERNEL = False - -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -TPSIZE = 2 -BATCH_SIZE = 8 -MAX_INPUT_LEN = 12 -MAX_OUTPUT_LEN = 100 - -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") - - -@parameterize( - "test_config", - [ - { - "tp_size": TPSIZE, - } - ], -) -def run_llama_test(test_config): - llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) - model = LlamaForCausalLM(llama_config) - model = model.half() - - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - - input_tokens = { - "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), - } - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - - assert outputs is not None - - -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_llama_test() - - -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_llama(): - spawn(check_llama, TPSIZE) - - -if __name__ == "__main__": - test_llama() From eaeb2a0777ffbea1247fa96ecbce4cfe27e1463f Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 8 Nov 2023 13:44:50 +0800 Subject: [PATCH 2/5] fix quant model --- colossalai/inference/quant/smoothquant/models/base_model.py | 3 +-- colossalai/inference/quant/smoothquant/models/llama.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 9fe3241cf5c3..763c6eb44ee4 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -20,8 +20,7 @@ from transformers.utils.generic import ContextManagers from transformers.utils.hub import PushToHubMixin, cached_file -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager +from colossalai.inference.kvcache_manager.batch_infer_state import BatchInferState, MemoryManager try: import accelerate diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 9d4bd9f7794b..60aa2be74a2d 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -21,7 +21,7 @@ ) from transformers.utils import add_start_docstrings_to_model_forward -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.kvcache_manager.batch_infer_state import BatchInferState from colossalai.kernel.triton import ( copy_kv_cache_to_dest, int8_rotary_embedding_fwd, From 24ecf77b766a238cf9fe9cdfdde152661bed1a9d Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 9 Nov 2023 14:20:22 +0800 Subject: [PATCH 3/5] fix test import bug --- tests/test_infer/test_hybrid_bloom.py | 8 +++++--- tests/test_infer/test_hybrid_llama.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_infer/test_hybrid_bloom.py b/tests/test_infer/test_hybrid_bloom.py index 14b745982686..336e70b78004 100644 --- a/tests/test_infer/test_hybrid_bloom.py +++ b/tests/test_infer/test_hybrid_bloom.py @@ -1,3 +1,5 @@ +import importlib.util + import pytest import torch import torch.distributed as dist @@ -9,9 +11,9 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -try: - HAS_LIGHTLLM_KERNEL = True -except: +HAS_LIGHTLLM_KERNEL = True + +if importlib.util.find_spec("lightllm") is None: HAS_LIGHTLLM_KERNEL = False diff --git a/tests/test_infer/test_hybrid_llama.py b/tests/test_infer/test_hybrid_llama.py index d917ae2d8ac2..6846c372de56 100644 --- a/tests/test_infer/test_hybrid_llama.py +++ b/tests/test_infer/test_hybrid_llama.py @@ -1,3 +1,5 @@ +import importlib.util + import pytest import torch import torch.distributed as dist @@ -9,9 +11,12 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -try: - HAS_LIGHTLLM_KERNEL = True -except: + +import importlib.util + +HAS_LIGHTLLM_KERNEL = True + +if importlib.util.find_spec("lightllm") is None: HAS_LIGHTLLM_KERNEL = False From 4a1150dd78bc3edd51eb928da47f5c90ef8f3e16 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 10 Nov 2023 11:38:19 +0800 Subject: [PATCH 4/5] mv original inference legacy --- colossalai/legacy/inference/README.md | 143 +++ colossalai/legacy/inference/__init__.py | 4 + colossalai/legacy/inference/async_engine.py | 133 +++ colossalai/legacy/inference/async_manager.py | 151 ++++ .../inference/dynamic_batching/__init__.py | 0 .../dynamic_batching/get_tokenizer.py | 40 + .../inference/dynamic_batching/infer_batch.py | 346 ++++++++ .../inference/dynamic_batching/io_struct.py | 166 ++++ .../dynamic_batching/ray_dist_init.py | 152 ++++ .../dynamic_batching/ray_init_config.py | 58 ++ .../inference/dynamic_batching/req_queue.py | 73 ++ .../dynamic_batching/sampling_params.py | 83 ++ .../inference/dynamic_batching/stats.py | 45 + .../legacy/inference/hybridengine/__init__.py | 3 + .../legacy/inference/hybridengine/engine.py | 170 ++++ .../hybridengine/modeling/__init__.py | 3 + .../inference/hybridengine/modeling/_utils.py | 67 ++ .../inference/hybridengine/modeling/llama.py | 489 ++++++++++ .../hybridengine/polices/__init__.py | 3 + .../inference/hybridengine/polices/llama.py | 142 +++ colossalai/legacy/inference/manager.py | 296 +++++++ .../legacy/inference/pipeline/README.md | 83 ++ .../legacy/inference/pipeline/__init__.py | 3 + .../inference/pipeline/benchmark/benchmark.py | 134 +++ .../inference/pipeline/benchmark/run.sh | 50 ++ .../inference/pipeline/microbatch_manager.py | 249 ++++++ .../legacy/inference/quant/gptq/__init__.py | 4 + .../inference/quant/gptq/cai_gptq/__init__.py | 14 + .../quant/gptq/cai_gptq/cai_quant_linear.py | 354 ++++++++ .../inference/quant/gptq/cai_gptq/gptq_op.py | 58 ++ .../inference/quant/smoothquant/__init__.py | 0 .../quant/smoothquant/models/__init__.py | 12 + .../quant/smoothquant/models/base_model.py | 487 ++++++++++ .../quant/smoothquant/models/linear.py | 179 ++++ .../quant/smoothquant/models/llama.py | 838 ++++++++++++++++++ .../inference/tensor_parallel/__init__.py | 4 + .../tensor_parallel/batch_infer_state.py | 118 +++ .../inference/tensor_parallel/engine.py | 480 ++++++++++ .../tensor_parallel/kvcache_manager.py | 106 +++ .../tensor_parallel/modeling/__init__.py | 5 + .../tensor_parallel/modeling/_utils.py | 67 ++ .../tensor_parallel/modeling/bloom.py | 540 +++++++++++ .../tensor_parallel/modeling/chatglm2.py | 545 ++++++++++++ .../tensor_parallel/modeling/llama.py | 423 +++++++++ .../tensor_parallel/policies/__init__.py | 5 + .../tensor_parallel/policies/bloom.py | 101 +++ .../tensor_parallel/policies/chatglm2.py | 77 ++ .../tensor_parallel/policies/llama.py | 121 +++ .../triton/test_token_attn_fwd.py | 12 +- 49 files changed, 7635 insertions(+), 1 deletion(-) create mode 100644 colossalai/legacy/inference/README.md create mode 100644 colossalai/legacy/inference/__init__.py create mode 100644 colossalai/legacy/inference/async_engine.py create mode 100644 colossalai/legacy/inference/async_manager.py create mode 100644 colossalai/legacy/inference/dynamic_batching/__init__.py create mode 100644 colossalai/legacy/inference/dynamic_batching/get_tokenizer.py create mode 100644 colossalai/legacy/inference/dynamic_batching/infer_batch.py create mode 100644 colossalai/legacy/inference/dynamic_batching/io_struct.py create mode 100644 colossalai/legacy/inference/dynamic_batching/ray_dist_init.py create mode 100644 colossalai/legacy/inference/dynamic_batching/ray_init_config.py create mode 100644 colossalai/legacy/inference/dynamic_batching/req_queue.py create mode 100644 colossalai/legacy/inference/dynamic_batching/sampling_params.py create mode 100644 colossalai/legacy/inference/dynamic_batching/stats.py create mode 100644 colossalai/legacy/inference/hybridengine/__init__.py create mode 100644 colossalai/legacy/inference/hybridengine/engine.py create mode 100644 colossalai/legacy/inference/hybridengine/modeling/__init__.py create mode 100644 colossalai/legacy/inference/hybridengine/modeling/_utils.py create mode 100644 colossalai/legacy/inference/hybridengine/modeling/llama.py create mode 100644 colossalai/legacy/inference/hybridengine/polices/__init__.py create mode 100644 colossalai/legacy/inference/hybridengine/polices/llama.py create mode 100644 colossalai/legacy/inference/manager.py create mode 100644 colossalai/legacy/inference/pipeline/README.md create mode 100644 colossalai/legacy/inference/pipeline/__init__.py create mode 100644 colossalai/legacy/inference/pipeline/benchmark/benchmark.py create mode 100644 colossalai/legacy/inference/pipeline/benchmark/run.sh create mode 100644 colossalai/legacy/inference/pipeline/microbatch_manager.py create mode 100644 colossalai/legacy/inference/quant/gptq/__init__.py create mode 100644 colossalai/legacy/inference/quant/gptq/cai_gptq/__init__.py create mode 100644 colossalai/legacy/inference/quant/gptq/cai_gptq/cai_quant_linear.py create mode 100644 colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py create mode 100644 colossalai/legacy/inference/quant/smoothquant/__init__.py create mode 100644 colossalai/legacy/inference/quant/smoothquant/models/__init__.py create mode 100644 colossalai/legacy/inference/quant/smoothquant/models/base_model.py create mode 100644 colossalai/legacy/inference/quant/smoothquant/models/linear.py create mode 100644 colossalai/legacy/inference/quant/smoothquant/models/llama.py create mode 100644 colossalai/legacy/inference/tensor_parallel/__init__.py create mode 100644 colossalai/legacy/inference/tensor_parallel/batch_infer_state.py create mode 100644 colossalai/legacy/inference/tensor_parallel/engine.py create mode 100644 colossalai/legacy/inference/tensor_parallel/kvcache_manager.py create mode 100644 colossalai/legacy/inference/tensor_parallel/modeling/__init__.py create mode 100644 colossalai/legacy/inference/tensor_parallel/modeling/_utils.py create mode 100644 colossalai/legacy/inference/tensor_parallel/modeling/bloom.py create mode 100644 colossalai/legacy/inference/tensor_parallel/modeling/chatglm2.py create mode 100644 colossalai/legacy/inference/tensor_parallel/modeling/llama.py create mode 100644 colossalai/legacy/inference/tensor_parallel/policies/__init__.py create mode 100644 colossalai/legacy/inference/tensor_parallel/policies/bloom.py create mode 100644 colossalai/legacy/inference/tensor_parallel/policies/chatglm2.py create mode 100644 colossalai/legacy/inference/tensor_parallel/policies/llama.py diff --git a/colossalai/legacy/inference/README.md b/colossalai/legacy/inference/README.md new file mode 100644 index 000000000000..f466f46c1629 --- /dev/null +++ b/colossalai/legacy/inference/README.md @@ -0,0 +1,143 @@ +# 🚀 Colossal-Inference + +## Table of contents + +## Introduction + +`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. + +## Design + +Colossal Inference is composed of two main components: + +1. High performance kernels and ops: which are inspired from existing libraries and modified correspondingly. +2. Efficient memory management mechanism:which includes the key-value cache manager, allowing for zero memory waste during inference. + 1. `cache manager`: serves as a memory manager to help manage the key-value cache, it integrates functions such as memory allocation, indexing and release. + 2. `batch_infer_info`: holds all essential elements of a batch inference, which is updated every batch. +3. High-level inference engine combined with `Shardformer`: it allows our inference framework to easily invoke and utilize various parallel methods. + 1. `engine.TPInferEngine`: it is a high level interface that integrates with shardformer, especially for multi-card (tensor parallel) inference: + 2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference. (in this case : llama) + 3. `policies.llama.LlamaModelInferPolicy` : contains the policies for `llama` models, which is used to call `shardformer` and segmentate the model forward in tensor parallelism way. + +## Pipeline of inference: + +In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes. + +![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Colossal-inference.png) + +## Roadmap of our implementation + +- [x] Design cache manager and batch infer state +- [x] Design TpInference engine to integrates with `Shardformer` +- [x] Register corresponding high-performance `kernel` and `ops` +- [x] Design policies and forwards (e.g. `Llama` and `Bloom`) + - [x] policy + - [x] context forward + - [x] token forward + - [x] support flash-decoding +- [ ] Replace the kernels with `faster-transformer` in token-forward stage +- [ ] Support all models + - [x] Llama + - [x] Llama-2 + - [x] Bloom + - [x] Chatglm2 +- [ ] Benchmarking for all models + +## Get started + +### Installation + +```bash +pip install -e . +``` + +### Requirements + +dependencies + +```bash +pytorch= 1.13.1 (gpu) +cuda>= 11.6 +transformers= 4.30.2 +triton +# for install flash-attention +flash-attention + +# install lightllm since we depend on lightllm triton kernels +git clone https://github.com/ModelTC/lightllm +cd lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +pip3 install -e . + +# also, install xformers from source: +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers + +``` + +### Docker + +You can use docker run to use docker container to set-up environment + +``` +# env: python==3.8, cuda 11.6, pytorch == 1.13.1 triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support +docker pull hpcaitech/colossalai-inference:v2 +docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash + +# enter into docker container +cd /path/to/CollossalAI +pip install -e . + +# install lightllm +git clone https://github.com/ModelTC/lightllm +cd lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +pip3 install -e . + +# install xformers from source +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers +``` + +### Dive into fast-inference! + +example files are in + +```bash +cd colossalai.examples +python xx +``` + +## Performance + +### environment: + +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `colossal-inference` and original `hugging-face torch fp16`. + +For various models, experiments were conducted using multiple batch sizes under the consistent model configuration of `7 billion(7b)` parameters, `1024` input length, and 128 output length. The obtained results are as follows (due to time constraints, the evaluation has currently been performed solely on the `A100` single GPU performance; multi-GPU performance will be addressed in the future): + +### Single GPU Performance: + +Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned. + +#### Llama + +| batch_size | 8 | 16 | 32 | +| :---------------------: | :----: | :----: | :----: | +| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 | +| colossal-inference | 326.4 | 582.72 | 816.64 | + +![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png) + +### Bloom + +| batch_size | 8 | 16 | 32 | +| :---------------------: | :----: | :----: | :----: | +| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 | +| colossal-inference | 323.28 | 538.52 | 611.64 | + +![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png) + +The results of more models are coming soon! diff --git a/colossalai/legacy/inference/__init__.py b/colossalai/legacy/inference/__init__.py new file mode 100644 index 000000000000..d5a988cfc6f0 --- /dev/null +++ b/colossalai/legacy/inference/__init__.py @@ -0,0 +1,4 @@ +from .hybridengine import CaiInferEngine +from .hybridengine.polices import LlamaModelInferPolicy + +__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"] diff --git a/colossalai/legacy/inference/async_engine.py b/colossalai/legacy/inference/async_engine.py new file mode 100644 index 000000000000..d0890ba3e9fc --- /dev/null +++ b/colossalai/legacy/inference/async_engine.py @@ -0,0 +1,133 @@ +import asyncio + +from colossalai.inference.dynamic_batching.ray_dist_init import Driver + +from .dynamic_batching.io_struct import RequestOutput +from .dynamic_batching.sampling_params import SamplingParams + + +class RequestTracker: + """ + A class for trace down all the requests, abstraction for async + """ + + def __init__(self) -> None: + self._requests: asyncio.Queue[str] = asyncio.Queue() + self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._requests + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def add_request(self, request_id: str): + """Add a request to be sent to the engine on the next background + loop iteration.""" + self._requests.put_nowait(request_id) + self.new_requests_event.set() # NOTE: we may find a better way to clear this event + + def add_stop(self): + """ + Add a StopIteration flag to stop async generator. + """ + self._finished_requests.put_nowait(StopIteration) + self.new_requests_event.clear() + + def process_request_output(self, request_output: RequestOutput) -> None: + """Process a request output from the engine.""" + self._finished_requests.put_nowait(request_output) + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._finished_requests.get() + # print("result of ", result) + if result is StopIteration: + raise StopAsyncIteration + return result + + +class Async_Engine: + + """ + Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager + Background loop: inference reqs in waiting list (Listen) + Request Tracker: manage incoming requests and restore finished ones + Generate: exposed func for add new input and return finished ones + """ + + def __init__( + self, + router_config, + engine_config, + start_engine_loop: bool = True, + ) -> None: + self.driver = Driver(router_config=router_config, engine_config=engine_config) + self.background_loop = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + def _step(self): + """ + Logic for handling requests + """ + request_outputs = self.driver.step() + if request_outputs is not None: + for request_output in request_outputs: + self._request_tracker.process_request_output(request_output) + self._request_tracker.add_stop() + + def abort_request(self, request_id: str): + self.driver.abort(request_id) + + def _has_requests_in_progress(self): + return self.driver.is_running() + + async def run_loop_fwd(self): + has_requests_in_progress = self._has_requests_in_progress() + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + self._step() + await asyncio.sleep(0) + + @property + def is_running(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.is_running: + raise RuntimeError("Background loop is already running.") + + self._request_tracker.init_event() + + self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) + self.background_loop = asyncio.shield(self.background_loop_unshielded) + + async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.driver.add_input(request_id, prompt, sampling_params) + self._request_tracker.add_request(request_id) + + async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + """ + The only exposed func, adding new request and return a async generator that yields the existing results. + """ + try: + if not self.is_running: + self.start_background_loop() + + await self.add_request(request_id, prompt, sampling_params) + + async for request_output in self._request_tracker: + yield request_output + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the request. + self.abort_request(request_id) + raise e diff --git a/colossalai/legacy/inference/async_manager.py b/colossalai/legacy/inference/async_manager.py new file mode 100644 index 000000000000..60440a792f1c --- /dev/null +++ b/colossalai/legacy/inference/async_manager.py @@ -0,0 +1,151 @@ +from typing import List + +from .dynamic_batching.io_struct import Batch, Req, RequestOutput +from .manager import DynamicBatchManager +from .tensor_parallel import TPInferEngine + + +class Async_DynamicBatchManager(DynamicBatchManager): + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num: int, + batch_max_tokens: int, + model: str, + tokenizer=None, + eos_id=None, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + super().__init__( + tp_engine, + max_total_token_num, + batch_max_tokens, + model, + tokenizer, + eos_id, + log_stats, + log_stats_interval, + running_batch, + waiting_req_list, + ) + + def _step(self): + """ + Logic for handling requests + """ + has_new_finished = False + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + has_new_finished, outputs = self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + + else: + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + has_new_finished, outputs = self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + if has_new_finished: + return outputs + return None + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + return self._output_process(finished_reqs) + return None + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + outputs = [] + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) + return outputs + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = Async_DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + raise Exception + + return batch_manager diff --git a/colossalai/legacy/inference/dynamic_batching/__init__.py b/colossalai/legacy/inference/dynamic_batching/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/legacy/inference/dynamic_batching/get_tokenizer.py b/colossalai/legacy/inference/dynamic_batching/get_tokenizer.py new file mode 100644 index 000000000000..94aa3f24393f --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/get_tokenizer.py @@ -0,0 +1,40 @@ +""" +Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. + +license: MIT, see LICENSE for more details. +""" + +from transformers import AutoTokenizer + +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" + + +def get_tokenizer( + tokenizer=None, + tokenizer_name: str = "", + trust_remote_code: bool = False, + use_fast: bool = True, +): + if tokenizer is not None: + tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." + ) + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + except TypeError: + use_fast = False + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + return tokenizer diff --git a/colossalai/legacy/inference/dynamic_batching/infer_batch.py b/colossalai/legacy/inference/dynamic_batching/infer_batch.py new file mode 100644 index 000000000000..112784c15f84 --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/infer_batch.py @@ -0,0 +1,346 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import collections +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from colossalai.inference.tensor_parallel import MemoryManager + + +# make batch infer state an attr of InferBatch +class InferSamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + vocab_size: int = -1, + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + if self.top_k == -1: + self.top_k = vocab_size + return + + +@dataclass +class InferBatch: + batch_id: int + requests: List + requests_idx_mapping: Dict[int, int] + + input_ids: torch.Tensor + + all_input_ids: List[List[int]] + input_lengths: List[int] + + out_token_id_counts: List + sampling_param_list: List[InferSamplingParams] + + nopad_total_token_num: int + nopad_max_len_in_batch: int + nopad_b_loc: torch.Tensor + nopad_b_start_loc: torch.Tensor + nopad_b_seq_len: torch.Tensor + cache_manager: MemoryManager + max_total_len: int + + @classmethod + @torch.no_grad() + def init_batch( + cls, + batch_id, + requests, + dtype: torch.dtype, + device: torch.device, + cache_manager: MemoryManager, + vocab_size: int, + max_total_len: int, + ) -> "InferBatch": + input_lengths = [] + all_input_ids = [] + requests_idx_mapping = {} + + out_token_id_counts = [] + sampling_param_list = [] + + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") + # to avoid memory leak , we pre-allocate 12 more space for each batch. + nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") + for i, r in enumerate(requests): + # request id -> idx in list mapping + requests_idx_mapping[r["request_id"]] = i + + tokenized_input = r["input_id"] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + out_token_id_counts.append(collections.defaultdict(int)) + + # postprocessor + sampling_param = r["sampling_param"] + sampling_param["vocab_size"] = vocab_size + sampling_param_list.append(InferSamplingParams(**sampling_param)) + + nopad_total_token_num += input_length + nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) + + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + + if len(requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + else: + input_ids = all_input_ids[0] + + # Create tensors on device + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + return cls( + batch_id=batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=cache_manager, + max_total_len=max_total_len, + ) + + @torch.no_grad() + def free_self(self) -> None: + """ + Free the memory of the InferBatch itself + """ + remove_index = [] + for idx in range(len(self)): + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + @torch.no_grad() + def filter(self, request_ids: List[int]) -> "InferBatch": + """ + Filter finished batch and return a new InferBatch with left ones. + """ + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + requests_idx_mapping = {} + indices = [] + requests = [] + all_input_ids = [] + input_lengths = [] + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + + left_idx = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + left_idx.append(idx) + + left_idx_set = set(left_idx) + remove_index = [] + for idx in range(len(self)): + if idx not in left_idx_set: + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + nopad_max_len_in_batch = 0 + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + + nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] + nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + nopad_total_token_num = torch.sum(nopad_b_seq_len).item() + + nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[ + indices, + (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1), + ] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + requests.append(self.requests[idx]) + all_input_ids.append(self.all_input_ids[idx]) + input_lengths.append(self.input_lengths[idx]) + + input_ids = self.input_ids[indices] + + return InferBatch( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], + sampling_param_list=[self.sampling_param_list[_i] for _i in indices], + cache_manager=self.cache_manager, + max_total_len=self.max_total_len, + ) + + @classmethod + @torch.no_grad() + def merge(cls, batch1, batch2) -> "InferBatch": + """ + Return megerd new InferBatch + """ + requests = batch1.requests + batch2.requests + requests_idx_mapping = {} + new_batch_size = len(batch1) + len(batch2) + + input_ids = batch1.input_ids.new_empty(new_batch_size) + all_input_ids = [] + input_lengths = [] + out_token_id_counts = [] + sampling_param_list = [] + + cumulative_batch_size = 0 + nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num + nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) + max_total_len = max(batch1.max_total_len, batch2.max_total_len) + nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_start_loc_len_temp = 0 + batches = [batch1, batch2] + for i, batch in enumerate(batches): + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + cumulative_batch_size + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + input_ids[start_index:end_index] = batch.input_ids + nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len + nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp + nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] + nopad_b_loc[ + start_index:end_index, + nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1, + ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1] + + all_input_ids.extend(batch.all_input_ids) + + input_lengths.extend(batch.input_lengths) + out_token_id_counts.extend(batch.out_token_id_counts) + sampling_param_list.extend(batch.sampling_param_list) + # Update + cumulative_batch_size += len(batch) + + nopad_b_loc[:, nopad_max_len_in_batch - 1] = ( + nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda") + ) + return InferBatch( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=batches[0].cache_manager, + max_total_len=max_total_len, + ) + + def __len__(self): + return len(self.requests) + + def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + top_ks: List[int] = [] + p_token_ids: List[int] = [] + p_token_counts: List[int] = [] + p_seq_len: List[int] = [ + 0, + ] + p_max_len_in_batch: int = 0 + for i, id_to_count in enumerate(self.out_token_id_counts): + sample_param = self.sampling_param_list[i] + presence_penalties.append(sample_param.presence_penalty) + frequency_penalties.append(sample_param.frequency_penalty) + temperatures.append(sample_param.temperature) + top_ps.append(sample_param.top_p) + top_ks.append(sample_param.top_k) + + for token_id, count in id_to_count.items(): + p_token_ids.append(token_id) + p_token_counts.append(count) + p_seq_len.append(len(id_to_count)) + p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) + + presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") + frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") + temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") + top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda") + top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda") + p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda") + p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") + p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") + p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32) + return ( + presence_penalties, + frequency_penalties, + temperatures, + top_ps, + top_ks, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + p_max_len_in_batch, + ) diff --git a/colossalai/legacy/inference/dynamic_batching/io_struct.py b/colossalai/legacy/inference/dynamic_batching/io_struct.py new file mode 100644 index 000000000000..fc5ecfe5796b --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/io_struct.py @@ -0,0 +1,166 @@ +# Adapted from https://github.com/ModelTC/lightllm + +from typing import Dict, List, Tuple + +from .sampling_params import SamplingParams + + +class Req: + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): + self.request_id = request_id + self.prompt_ids = prompt_ids + self.input_len = len(prompt_ids) + self.max_output_len = sample_params.max_new_tokens + self.sample_params = sample_params + self.output_ids = [] + self.output_metadata_list = [] + self.has_generate_finished = False + self.aborted = False + self.prompts = prompts + + def to_rpc_obj(self): + return { + "request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict(), + } + + def stop_sequences_matched(self): + # should we add stpp sequences to the sample params? + if self.sample_params.stop_sequences is not None: + for stop_token_ids in self.sample_params.stop_sequences: + stop_len = len(stop_token_ids) + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): + return True + return False + + def __repr__(self): + return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " + + +class Batch: + def __init__(self, batch_id, reqs: List[Req]): + self.batch_id = batch_id + self.reqs = reqs + self.id_to_reqs = {req.request_id: req for req in reqs} + + def input_tokens(self): + batch_input_tokens = 0 + for req in self.reqs: + batch_input_tokens += req.input_len + return batch_input_tokens + + def calcu_max_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + req.max_output_len + return tokens + + def calcu_used_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + len(req.output_ids) + return tokens + + def mark_finished_req(self, eos_id, engine_max_output_len): + has_new_finish = False + for req in self.reqs: + if req.stop_sequences_matched(): + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= engine_max_output_len: + req.has_generate_finished = True + has_new_finish = True + if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= req.max_output_len or req.aborted: + req.has_generate_finished = True + has_new_finish = True + return has_new_finish + + def filter_finished(self) -> List[Req]: + """ + Filter finished requests from the batch, the finished ones will be removed from 'reqs'. + """ + # TODO: the logic of return should be defined here. + unfinished_req = [] + finished_req = [] + for req in self.reqs: + if not req.has_generate_finished: + unfinished_req.append(req) + else: + finished_req.append(req) + self.reqs = unfinished_req + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req + + def is_clear(self): + return len(self.reqs) == 0 + + def merge(self, mini_batch): + for _req in mini_batch.reqs: + self.reqs.append(_req) + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return + + def __repr__(self): + return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + + def __len__(self): + return len(self.reqs) + + +class BatchTokenIdOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, int, Dict, bool, bool] + ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + + +class BatchStrOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, str, Dict, bool, bool] + ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + + +class AbortReq: + def __init__(self, req_id): + self.req_id = req_id + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + outputs: The output sequences of the request. + """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + outputs, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs + + def __repr__(self) -> str: + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"outputs={self.outputs}, " + ) diff --git a/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py b/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py new file mode 100644 index 000000000000..70ef489d3b70 --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py @@ -0,0 +1,152 @@ +import logging +import os +from typing import List + +import ray +import ray.util.collective as collective +import torch +from transformers import AutoModelForCausalLM + +import colossalai +from colossalai.inference.async_manager import start_dynamic_batching +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.io_struct import RequestOutput +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +ray_serve_logger = logging.getLogger("ray.serve") + + +def log_cuda_info(scope_name: str): + ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") + ray_serve_logger.info( + f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" + ) + if torch.cuda.is_available(): + ray_serve_logger.info( + f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" + ) + else: + ray_serve_logger.info(f" {scope_name}: cuda is not available!") + + +@ray.remote(num_gpus=1) +class Worker: + def __init__( + self, + model_path: str, + tensor_parallel_size: int, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + router_config: RooterArgsClass, + ): + log_cuda_info("Worker.init") + self.tensor_parallel_size = tensor_parallel_size + self.model_path = model_path + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.router_config = router_config + + def setup(self, world_size, rank, port): + # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully + collective.init_collective_group(world_size, rank, "nccl", "default") + # initialize and set distributed environment + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") + log_cuda_info("Worker.setup") + + # Load model + self.tokenizer = get_tokenizer(tokenizer_name=self.model_path) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 + ) + shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, []) + + return True + + # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]: + # ray_serve_logger.info(f"text: {prompt}") + + # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + + # return final_outputs + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.start_dynamic_batching.add_input(request_id, prompt, sampling_params) + + def abort(self, request_id: str): + self.start_dynamic_batching.abort(request_id) + + def step(self) -> List[RequestOutput]: + return self.start_dynamic_batching._step() + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) + + def is_running(self): + return self.start_dynamic_batching.is_running() + + +class Driver: + def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): + log_cuda_info("Driver:init") + model_path = engine_config.model + tensor_parallel_size = engine_config.tensor_parallel_size + + self.num_workers = tensor_parallel_size + self.workers = [] + init_rets = [] + + # Just grab a free port on localhost + # NOTE workers in this communication group listen to the same port + available_port = free_port() + + for i in range(self.num_workers): + worker_name = "worker_idx_{}".format(i) + w = Worker.options(name=worker_name).remote( + model_path, + self.num_workers, + engine_config.max_batch_size, + engine_config.max_input_len, + engine_config.max_output_len, + router_config, + ) + self.workers.append(w) + init_rets.append(w.setup.remote(self.num_workers, i, available_port)) + _options = { + "group_name": "default_driver", + "world_size": self.num_workers, + "ranks": [i for i in range(self.num_workers)], + "backend": "nccl", + } + collective.create_collective_group(self.workers, **_options) + _ = ray.get(init_rets) + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) + + def abort(self, request_id: str): + ray.get([w.abort.remote(request_id) for w in self.workers]) + + def step(self): + results = ray.get([w.step.remote() for w in self.workers]) + outputs = results[0] # get any one of the copies + return outputs + + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str): + ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) + + def is_running(self): + results = ray.get([w.is_running.remote() for w in self.workers]) + return any(results) diff --git a/colossalai/legacy/inference/dynamic_batching/ray_init_config.py b/colossalai/legacy/inference/dynamic_batching/ray_init_config.py new file mode 100644 index 000000000000..471f07330aec --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/ray_init_config.py @@ -0,0 +1,58 @@ +import logging + +import yaml +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class EngineArgsClass(BaseModel): + """Config for Engine""" + + model: str + tensor_parallel_size: int = 2 + max_batch_size: int = 4 + max_input_len: int = 128 + max_output_len: int = 32 + + +class RooterArgsClass(BaseModel): + """Config for Rooter""" + + max_total_token_num: int = 42 + batch_max_tokens: int = 42 + eos_id: int = 0 + disable_log_stats: bool = False + log_stats_interval: int = 10 + model: str + + +class RayInitConfig(BaseModel): + """All-together configs without app router config""" + + engine_config_data: EngineArgsClass + router_config_data: RooterArgsClass + + @classmethod + def from_yaml_path(cls, path: str): + try: + with open(path, "r") as yaml_file: + try: + config = yaml.safe_load(yaml_file) + # serve deployment config + engine_config = config.get("engine_config", {}) + router_config = config.get("router_config", {}) + + return cls( + engine_config_data=engine_config, + router_config_data=router_config, + ) + except yaml.YAMLError as e: + logger.error(f"An Error occurred when parsing yaml: {e}") + raise + except FileNotFoundError: + logger.error(f"The file '{path}' does not exist!") + raise + except OSError as e: + logger.error(f"An Error occurred: {e}") + raise diff --git a/colossalai/legacy/inference/dynamic_batching/req_queue.py b/colossalai/legacy/inference/dynamic_batching/req_queue.py new file mode 100644 index 000000000000..0de43bd1a21f --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/req_queue.py @@ -0,0 +1,73 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import uuid +from typing import List + +import numpy as np + +from .io_struct import Batch, Req + + +class ReqQueue: + def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None: + self.max_total_tokens = max_total_tokens + assert batch_max_tokens is not None + self.batch_max_tokens = batch_max_tokens + self.running_max_req_size = running_max_req_size + self.waiting_req_list: List[Req] = waiting_req_list + + def append(self, req): + self.waiting_req_list.append(req) + return + + def _init_cache_list(self, current_batch: Batch): + if current_batch is not None: + self.cache_len_list = [ + (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) + for req in current_batch.reqs + ] + else: + self.cache_len_list = [] + + # @calculate_time(show=True, min_cost_ms=0.1) + def _can_add_new_req(self, req): + self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis + self.cache_len_list.sort(key=lambda x: -x[1]) + + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + # assert left_out_len_array.min() >= 0 + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + # NOTE: change here < to <= + return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size + + def generate_new_batch(self, current_batch: Batch = None): + if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: + return None + self._init_cache_list(current_batch) + can_run_list = [] + new_batch_total_tokens = 0 + aborted_count = 0 + for req in self.waiting_req_list: + flag = self._can_add_new_req(req) + if req.aborted: + aborted_count += 1 + continue + if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: + can_run_list.append(req) + new_batch_total_tokens += req.input_len + else: + break + + if len(can_run_list) != 0: + new_batch = Batch(uuid.uuid4().hex, can_run_list) + self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] + return new_batch + else: + return None + + def __len__(self): + return self.waiting_req_list.__len__() diff --git a/colossalai/legacy/inference/dynamic_batching/sampling_params.py b/colossalai/legacy/inference/dynamic_batching/sampling_params.py new file mode 100644 index 000000000000..a37a83390021 --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/sampling_params.py @@ -0,0 +1,83 @@ +# Adapted from https://github.com/ModelTC/lightllm + +"""Sampling parameters for text generation.""" +from typing import List, Optional, Union + +_SAMPLING_EPS = 1e-5 + + +class SamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, # -1 is for all + ignore_eos: bool = False, + max_new_tokens: int = 256, + stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.ignore_eos = ignore_eos + self.max_new_tokens = max_new_tokens + self.stop_sequences = stop_sequences + if self.do_sample == False: + self.temperature = 1.0 + self.top_p = 1.0 + self.top_k = 1 + if ( + self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS + ): # temperature is too slow, change to greedy search + self.temperature = 1.0 + self.top_k = 1 + return + + def verify(self): + if self.presence_penalty < 0.0: + raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") + if self.frequency_penalty < 0.0: + raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}") + if self.temperature <= 0.0: + raise ValueError(f"temperature must > 0.0, got {self.temperature}") + if self.top_p <= 0.0 or self.top_p > 1.0: + raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.") + if self.max_new_tokens < 1: + raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.") + return + + def stop_sentences_to_token_ids(self, tokenizer): + if self.stop_sequences is None: + self.stop_sequences = [] + else: + if isinstance(self.stop_sequences, str): + self.stop_sequences = [self.stop_sequences] + new_stop_sequences = [] + for stop_str in self.stop_sequences: + stop_str_ids = tokenizer.encode(stop_str) + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + stop_str_ids = stop_str_ids[1:] + if len(stop_str_ids) > 0: + new_stop_sequences.append(stop_str_ids) + self.stop_sequences = new_stop_sequences + return + + def to_dict(self): + ret = {} + ret["do_sample"] = self.do_sample + ret["presence_penalty"] = self.presence_penalty + ret["frequency_penalty"] = self.frequency_penalty + ret["temperature"] = self.temperature + ret["top_p"] = self.top_p + ret["top_k"] = self.top_k + # if self.ignore_eos is not None: + # ret["ignore_eos"] = self.ignore_eos + return ret diff --git a/colossalai/legacy/inference/dynamic_batching/stats.py b/colossalai/legacy/inference/dynamic_batching/stats.py new file mode 100644 index 000000000000..524072861a3f --- /dev/null +++ b/colossalai/legacy/inference/dynamic_batching/stats.py @@ -0,0 +1,45 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import time + + +class Stats: + def __init__(self, log_status, log_stats_interval) -> None: + self.log_stats = log_status + self.log_stats_interval = log_stats_interval + self.last_log_time = time.time() + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + return + + def count_prompt_tokens(self, run_batch): + if self.log_stats: + tokens = run_batch.input_tokens() + self.prompt_tokens += tokens + self.all_tokens += tokens + return + + def count_output_tokens(self, run_batch): + if self.log_stats: + tokens = len(run_batch.reqs) + self.output_tokens += tokens + self.all_tokens += tokens + return + + def print_stats(self): + if not self.log_stats: + return + + now = time.time() + if now - self.last_log_time > self.log_stats_interval: + print( + f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s" + ) + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + self.last_log_time = now + return diff --git a/colossalai/legacy/inference/hybridengine/__init__.py b/colossalai/legacy/inference/hybridengine/__init__.py new file mode 100644 index 000000000000..6377ef817301 --- /dev/null +++ b/colossalai/legacy/inference/hybridengine/__init__.py @@ -0,0 +1,3 @@ +from .engine import CaiInferEngine + +__all__ = ["CaiInferEngine"] diff --git a/colossalai/legacy/inference/hybridengine/engine.py b/colossalai/legacy/inference/hybridengine/engine.py new file mode 100644 index 000000000000..bb0b4c77a2a7 --- /dev/null +++ b/colossalai/legacy/inference/hybridengine/engine.py @@ -0,0 +1,170 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers.tokenization_utils_base import BatchEncoding + +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.schedule.generate import GenerateSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy + +from ..pipeline.microbatch_manager import MicroBatchManager +from ..tensor_parallel.kvcache_manager import MemoryManager + +PP_AXIS, TP_AXIS = 0, 1 + +_supported_models = [ + "LlamaForCausalLM", +] + + +class CaiInferEngine: + """ + CaiInferEngine is a class that handles the pipeline parallel inference. + + Args: + tp_size (int): the size of tensor parallelism. + pp_size (int): the size of pipeline parallelism. + model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. + model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. + micro_batch_size (int): the micro batch size. + micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + max_batch_size (int): the maximum batch size. + max_input_len (int): the maximum input length. + max_output_len (int): the maximum output length. + + Example: + + ```python + from colossalai.inference import InferEngine + from colossalai.inference.pipeline.policies import LlamaModelInferPolicy + import colossalai + from transformers import LlamaForCausalLM, LlamaTokenizer + + colossalai.launch_from_torch(config={}) + + model = LlamaForCausalLM.from_pretrained("your_path_to_model") + tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") + # assume the model is infered with 2 pipeline stages + inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy()) + + input = ["Introduce a landmark in China ","Introduce a landmark in China "] + data = tokenizer(input, return_tensors='pt') + output = inferengine.inference([data.to('cuda').data]) + + ``` + + """ + + def __init__( + self, + tp_size: int = 1, + pp_size: int = 1, + dtype: str = "fp16", + model: nn.Module = None, + model_policy: Policy = None, + micro_batch_size: int = 1, + micro_batch_buffer_size: int = None, + max_batch_size: int = 4, + max_input_len: int = 32, + max_output_len: int = 32, + verbose: bool = False, + # TODO: implement early_stopping, and various gerneration options + early_stopping: bool = False, + do_sample: bool = False, + num_beams: int = 1, + ) -> None: + assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported." + assert ( + tp_size * pp_size == dist.get_world_size() + ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})" + assert model and model_policy, "Model with model_policy should be provided." + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + + assert max_batch_size <= 64, "Max batch size exceeds the constraint" + assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint" + + # TODO: support only tensor parallel inference + assert pp_size > 1, "Not support only tensor parallel inference." + self.pp_size = pp_size + self.tp_size = tp_size + + if dtype == "fp16": + self.dtype = torch.float16 + model.half() + elif dtype == "bf16": + self.dtype = torch.bfloat16 + model.to(torch.bfloat16) + else: + self.dtype = torch.float32 + + # Init pg mesh + pg_mesh = ProcessGroupMesh(pp_size, tp_size) + + stage_manager = None + if pp_size > 1: + stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True) + self.cache_manager_list = [ + self._init_manager(model, max_batch_size, max_input_len, max_output_len) + for _ in range(micro_batch_buffer_size or pp_size) + ] + self.mb_manager = MicroBatchManager( + stage_manager.stage, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, + ) + self.verbose = verbose + self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose) + + self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS)) + + def inference(self, input_list): + """ + Args: + input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`. + + Returns: + out (list): a list of output data, each element is a list of token. + timestamp (float): the time cost of the inference, only return when verbose is `True`. + """ + assert isinstance( + input_list, (BatchEncoding, dict) + ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." + if isinstance(input_list, BatchEncoding): + input_list = input_list.data + out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) + if self.verbose: + return out, timestamp + else: + return out + + def _shardformer(self, model, model_policy, stage_manager, tp_group): + shardconfig = ShardConfig( + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + ) + shardformer = ShardFormer(shard_config=shardconfig) + shard_model, _ = shardformer.optimize(model, model_policy) + return shard_model.cuda() + + def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: + max_total_token_num = max_batch_size * (max_input_len + max_output_len) + head_dim = model.config.hidden_size // model.config.num_attention_heads + head_num = model.config.num_attention_heads + num_hidden_layers = ( + model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers + ) + layer_num = num_hidden_layers // self.pp_size + + cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num) + return cache_manager diff --git a/colossalai/legacy/inference/hybridengine/modeling/__init__.py b/colossalai/legacy/inference/hybridengine/modeling/__init__.py new file mode 100644 index 000000000000..239bdebd7efd --- /dev/null +++ b/colossalai/legacy/inference/hybridengine/modeling/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaInferenceForwards + +__all__ = ["LlamaInferenceForwards"] diff --git a/colossalai/legacy/inference/hybridengine/modeling/_utils.py b/colossalai/legacy/inference/hybridengine/modeling/_utils.py new file mode 100644 index 000000000000..068b64b4f829 --- /dev/null +++ b/colossalai/legacy/inference/hybridengine/modeling/_utils.py @@ -0,0 +1,67 @@ +""" +Utils for model inference +""" +import os + +import torch + +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + """ + This function copies the key and value cache to the memory cache + Args: + layer_id : id of current layer + key_buffer : key cache + value_buffer : value cache + context_mem_index : index of memory cache in kv cache manager + mem_manager : cache manager + """ + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/legacy/inference/hybridengine/modeling/llama.py b/colossalai/legacy/inference/hybridengine/modeling/llama.py new file mode 100644 index 000000000000..34474d115c8f --- /dev/null +++ b/colossalai/legacy/inference/hybridengine/modeling/llama.py @@ -0,0 +1,489 @@ +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +import math +from typing import List, Optional, Tuple + +import torch +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel +from transformers.utils import logging + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards +from colossalai.pipeline.stage_manager import PipelineStageManager + +from ._utils import copy_kv_to_mem_cache + +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + +try: + from flash_attn import flash_attn_with_kvcache + + HAS_FLASH_KERNEL = True +except: + HAS_FLASH_KERNEL = False + print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def llama_triton_context_attention( + query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 +): + if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + lightllm_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + + +def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" + if num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + + +class LlamaInferenceForwards: + """ + This class holds forwards for llama inference. + We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. + """ + + @staticmethod + def llama_causal_lm_forward( + self: LlamaForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + """ + logger = logging.get_logger(__name__) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {"logits": lm_logits} + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaInferenceForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + return outputs + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + use_cache = use_cache if use_cache is not None else self.config.use_cache + # retrieve input_ids and inputs_embeds + if stage_manager is None or stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + assert stage_manager is not None + assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}" + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assume prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + else: + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.repeat(batch_size, 1) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) + + # decoder layers + infer_state.decode_layer_id = 0 + + start_idx, end_idx = stage_index[0], stage_index[1] + if past_key_values is None: + past_key_values = tuple([None] * (end_idx - start_idx + 1)) + + for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): + decoder_layer = self.layers[idx] + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + infer_state.decode_layer_id += 1 + hidden_states = layer_outputs[0] + + if stage_manager.is_last_stage() or stage_manager.num_stages == 1: + hidden_states = self.norm(hidden_states) + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 + + # if not return_dict: + # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + return {"hidden_states": hidden_states} + + @staticmethod + def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # NOTE might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + + cos, sin = infer_state.position_cos, infer_state.position_sin + + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + attn_output = torch.empty_like(query_states) + + llama_triton_context_attention( + query_states, + key_states, + value_states, + attn_output, + infer_state, + num_key_value_groups=self.num_key_value_groups, + ) + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + if HAS_LIGHTLLM_KERNEL: + attn_output = torch.empty_like(query_states) + llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups + ) + else: + self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache( + q=query_states, + k_cache=copy_cache_k, + v_cache=copy_cache_v, + softmax_scale=1 / math.sqrt(self.head_dim), + causal=True, + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None diff --git a/colossalai/legacy/inference/hybridengine/polices/__init__.py b/colossalai/legacy/inference/hybridengine/polices/__init__.py new file mode 100644 index 000000000000..7271812c5366 --- /dev/null +++ b/colossalai/legacy/inference/hybridengine/polices/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaModelInferPolicy + +__all__ = ["LlamaModelInferPolicy"] diff --git a/colossalai/legacy/inference/hybridengine/polices/llama.py b/colossalai/legacy/inference/hybridengine/polices/llama.py new file mode 100644 index 000000000000..992299714bd1 --- /dev/null +++ b/colossalai/legacy/inference/hybridengine/polices/llama.py @@ -0,0 +1,142 @@ +from functools import partial +from typing import List + +import torch +from torch.nn import Module +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) + +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + +# import colossalai +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.llama import LlamaInferenceForwards + +try: + from colossalai.kernel.triton import rmsnorm_forward + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + ], + ) + + self.shard_config._infer() + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy + ) + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_first_stage(): + held_layers.append(self.model.lm_head) + return held_layers diff --git a/colossalai/legacy/inference/manager.py b/colossalai/legacy/inference/manager.py new file mode 100644 index 000000000000..9672a50141a0 --- /dev/null +++ b/colossalai/legacy/inference/manager.py @@ -0,0 +1,296 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import time +from typing import List + +from .dynamic_batching.get_tokenizer import get_tokenizer +from .dynamic_batching.infer_batch import InferBatch +from .dynamic_batching.io_struct import Batch, Req +from .dynamic_batching.req_queue import ReqQueue +from .dynamic_batching.sampling_params import SamplingParams +from .dynamic_batching.stats import Stats +from .tensor_parallel import TPInferEngine + + +class DynamicBatchManager: + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + model, + tokenizer=None, + eos_id=None, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + self.engine = tp_engine + self.max_total_token_num = max_total_token_num + running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) + # all the inputs should be put into req_queue: waiting req list + assert max_total_token_num >= self.engine.max_batch_size * ( + self.engine.max_input_len + self.engine.max_output_len + ), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)" + assert ( + batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len + ), "batch_max_tokens should be greater than (max_input_len+max_output_len)" + self.running_batch: Batch = running_batch + self.eos_id = eos_id + self.has_wait_tokens = 0 + self.max_wait_tokens = 10 + self.model = model + + self.stats_tool = Stats(log_stats, log_stats_interval) + self.mem_usage_interval = log_stats_interval * 2 + self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer + if self.eos_id == None: + self.eos_id = self.tokenizer.eos_token_id + + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): + """ + Add new request to req queue, during initialization all requests are held in waiting list. + """ + sampling_params.max_new_tokens = ( + self.engine.max_output_len + if sampling_params.max_new_tokens > self.engine.max_output_len + else sampling_params.max_new_tokens + ) + req = Req(request_id, prompt_ids, sampling_params, prompts) + self.req_queue.append(req) + return + + def add_input(self, request_id, prompts, sampling_params): + """ + Encode and Add new input to req queue. support one sequence input for now. + """ + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(request_id, prompt_ids, sampling_params, prompts) + return + + def abort(self, request_id): + if self.running_batch is not None: + for req in self.running_batch.reqs: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + for req in self.req_queue.waiting_req_list: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + return + + def loop_for_fwd(self): + """ + The main loop for a dynamic batching process. + """ + counter_count = 0 + # self.running_batch is not None or self.req_queue.waiting_req_list + while self.running_batch is not None or self.req_queue.waiting_req_list: + yield from self._step() + counter_count += 1 + if self.running_batch is not None: + if counter_count % self.mem_usage_interval == 0: + print( + "current batch size:", + len(self.running_batch.reqs), + "token used ratio:", + self.running_batch.calcu_used_tokens() / self.max_total_token_num, + ) + self.stats_tool.print_stats() + + if self.running_batch is None: + time.sleep(0.1) # 10ms + + def _step(self): + """ + Logic for handling requests + """ + + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + yield from self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + return + + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + yield from self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + return + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + yield from self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + return + + def _init_batch(self, batch: Batch, dtype="fp16"): + reqs = [r.to_rpc_obj() for r in batch.reqs] + batch_id = batch.batch_id + + import torch + + if dtype == "fp16": + dtype = torch.float16 + else: + assert False, "error dtype" + + batch_data = InferBatch.init_batch( + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.engine.cache_manager, + self.engine.model.config.vocab_size, + self.engine.max_input_len + self.engine.max_output_len, + ) + self.engine.cache[batch_id] = batch_data + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + yield from self._handle_finish_req(batch, has_new_finished_req) + + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + yield from self._handle_finish_req(batch, has_new_finished_req) + + def _filter_batch(self, batch: Batch): + batch_id = batch.batch_id + req_id_list = [r.request_id for r in batch.reqs] + batch = self.engine.cache.pop(batch_id) + filter_batch = batch.filter(req_id_list) + del batch + self.engine.cache[batch_id] = filter_batch + + def _merge_batch(self, batch1, batch2): + """ + Merge new mini batch into running batch. + """ + batch1 = self.engine.cache.pop(batch1.batch_id) + batch2 = self.engine.cache.pop(batch2.batch_id) + + m_batch = InferBatch.merge(batch1, batch2) + self.engine.cache[batch1.batch_id] = m_batch + del batch1 + del batch2 + + def _remove_batch(self, batch): + """ + Remove finished batch. + """ + batch = self.engine.cache.pop(batch.batch_id) + batch.free_self() + del batch + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + yield from self._output_process(finished_reqs) + + def _filter_runing_batch(self): + if self.running_batch is not None and self.running_batch.is_clear(): + self.running_batch = None + + def _add_token_id_to_req(self, batch: Batch, req_ans): + for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): + req = batch.id_to_reqs[req_id] + req.output_ids.append(new_token_id) + req.output_metadata_list.append(new_gen_metadata) + return + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield req.prompts + output + + def clean_up(self): + # this logic should be implemented in the future. + pass + + def generate(self, request_id, prompts, sampling_params): + """ + Generate the output of a request. + """ + self.add_input(request_id, prompts, sampling_params) + return self.loop_for_fwd() + + def is_running(self): + return self.running_batch is not None or self.req_queue.waiting_req_list + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + raise Exception + + return batch_manager diff --git a/colossalai/legacy/inference/pipeline/README.md b/colossalai/legacy/inference/pipeline/README.md new file mode 100644 index 000000000000..f9bb35cc4d4c --- /dev/null +++ b/colossalai/legacy/inference/pipeline/README.md @@ -0,0 +1,83 @@ +# 🐳 Pipeline Inference + +## Table of Contents +- [💡 Introduction](#introduction) +- [🔗 Design](#design) +- [🔨 Usage](#usage) + - [Example](#example) + - [Quick start](#quick-start) +- [📊 Performance](#performance) + +## Introduction + +`Pipeline Inference` is a module designed to make inference on a pipeline way. In inference systems, although there is no need to store intermediate information such as activations during forward propagation for backward propagation, the weights of some larger models still cannot fit on a single GPU for inference. This requires us to use model parallelism and other methods to reduce the memory occupation on a single GPU. Pipeline parallelism, as one of the traditional model parallelism approaches, has been widely used due to its reduced all-reduce communication requirements and simple layout. The main issue with pipeline parallelism, known as bubbles, can be almost eliminated in inference because the backward propagation that causes bubbles no longer exists in inference. This makes pipeline parallelism almost bubble-free in the ideal scenario where the sequence length is the same across the pipeline. + +## Design + +Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py). + +1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks: + - Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`. + - Run the pipeline inference model. + +2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks: + - Record each micro-batch information, like generated new tokens and kvcache. + - Record each micro-batch inference state, like prefill, generate or done. + - Update the micro-batch information. + +3. `generate` schedule implements the simple pipeline inference layout. When pipeline size is 2, we use `torch.distributed.P2Pop` to implement the communication between stages, mainly to solve the race communication. When pipeline size is larger than 2, we use `torch.distributed.broadcast` which is faster than `torch.distributed.P2Pop`. + +## Usage + +### Example +```python +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy +import colossalai +from transformers import LlamaForCausalLM, LlamaTokenizer + +colossalai.launch_from_torch(config={}) + +model = LlamaForCausalLM.from_pretrained("/path/to/model") +tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") + +# assume the model is inferred with 2 pipeline stages +inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32) + +input = ["Introduce a landmark in London","Introduce a landmark in Singapore"] +data = tokenizer(input, return_tensors='pt') +output = inferengine.inference(data.to('cuda')) +print(tokenizer.batch_decode(output)) +``` + +## Performance + +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G. + +### Llama Throughput (tokens/s) | input length=1024, output length=128 + +#### A10 7b, fp16 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| +| :---: | :---: | :---: | :---: | :---: | :---: | :---:| +| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM | +| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | + +#### A10 13b, fp16 +| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | +| :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | +| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | + + +#### A800 7b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | +| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | + + +#### A800 13b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 | +| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 | diff --git a/colossalai/legacy/inference/pipeline/__init__.py b/colossalai/legacy/inference/pipeline/__init__.py new file mode 100644 index 000000000000..f43e4a847448 --- /dev/null +++ b/colossalai/legacy/inference/pipeline/__init__.py @@ -0,0 +1,3 @@ +from .microbatch_manager import MicroBatchManager + +__all__ = ["MicroBatchManager"] diff --git a/colossalai/legacy/inference/pipeline/benchmark/benchmark.py b/colossalai/legacy/inference/pipeline/benchmark/benchmark.py new file mode 100644 index 000000000000..8392d0a1e579 --- /dev/null +++ b/colossalai/legacy/inference/pipeline/benchmark/benchmark.py @@ -0,0 +1,134 @@ +import argparse +import time + +import torch +import torch.distributed as dist +import transformers + +import colossalai +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy + +GIGABYTE = 1024**3 +MEGABYTE = 1024 * 1024 + +colossalai.launch_from_torch(config={}) + + +def data_gen(batch_size: int = 4, seq_len: int = 512): + input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) + attention_mask = torch.ones((1, seq_len), dtype=torch.int32) + data = dict(input_ids=input_ids, attention_mask=attention_mask) + for k, v in data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = batch_size + data[k] = v.to("cuda").repeat(*new_shape) + return data + + +def print_details_info(timestamps, model_config, args, whole_end2end): + if dist.get_rank() == 0: + prefill = [] + encoder = [] + end2end = [] + for timestamp in timestamps: + prefill.append(timestamp[1] - timestamp[0]) + encoder.append( + sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2) + ) + end2end.append(timestamp[-1] - timestamp[0]) + print(whole_end2end) + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "w+", + ) as f: + mb_avg_end2end = sum(end2end) / len(end2end) + mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size) + whole_avg_latency = whole_end2end / (args.new_length * args.batch_size) + num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) + num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size + if args.dtype in ["fp16", "bf16"]: + num_bytes = 2 + else: + num_bytes = 4 + + f.write( + f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n" + ) + f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000)) + f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000)) + f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000)) + f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) + f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000)) + f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) + f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000)))) + f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12)) + f.write("----------------------------------------------------------\n") + + if torch.cuda.is_available(): + current_device = torch.cuda.current_device() + + # free memory and the total available memory in bytes + global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info() + memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + memory_reserved = torch.cuda.memory_reserved() + max_memory_reserved = torch.cuda.max_memory_reserved() + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "a", + ) as f: + f.write( + f"\nCurrently using GPU: {current_device}\n" + f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" + f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n" + f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n" + f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="toy", help="the size of model") + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") + parser.add_argument("--new_length", type=int, default=4, help="new tokens length") + parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") + parser.add_argument("--pp_size", type=int, default=2, help="pipeline size") + parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log") + parser.add_argument("--dtype", type=str, default="fp16", help="data type") + args = parser.parse_args() + + if args.model == "toy": + model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) + elif args.model == "7b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")) + elif args.model == "13b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf")) + else: + raise NotImplementedError + + engine = PPInferEngine( + pp_size=args.pp_size, + dtype=args.dtype, + micro_batch_size=args.mb_size, + new_length=args.new_length, + model=model, + model_policy=LlamaModelInferPolicy(), + verbose=True, + max_batch_size=args.mb_size, + max_input_len=args.seq_len, + max_output_len=args.seq_len + args.new_length + 256, + ) + data = data_gen(args.batch_size, args.seq_len) + + torch.cuda.synchronize() + whole_end2end = time.time() + output, timestamps = engine.inference([data]) + torch.cuda.synchronize() + whole_end2end = time.time() - whole_end2end + + print_details_info(timestamps, model.config, args, whole_end2end) diff --git a/colossalai/legacy/inference/pipeline/benchmark/run.sh b/colossalai/legacy/inference/pipeline/benchmark/run.sh new file mode 100644 index 000000000000..e3c33bb88db1 --- /dev/null +++ b/colossalai/legacy/inference/pipeline/benchmark/run.sh @@ -0,0 +1,50 @@ +script_dir=$(cd "$(dirname "$0")" && pwd) +cd "${script_dir}" + +# 7b, fp16, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp16, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16 32; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp16, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 13b, fp16, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done diff --git a/colossalai/legacy/inference/pipeline/microbatch_manager.py b/colossalai/legacy/inference/pipeline/microbatch_manager.py new file mode 100644 index 000000000000..441cf603985c --- /dev/null +++ b/colossalai/legacy/inference/pipeline/microbatch_manager.py @@ -0,0 +1,249 @@ +from enum import Enum +from typing import Dict + +import torch + +from ..tensor_parallel.batch_infer_state import BatchInferState +from ..tensor_parallel.kvcache_manager import MemoryManager + +__all__ = "MicroBatchManager" + + +class Status(Enum): + PREFILL = 1 + GENERATE = 2 + DONE = 3 + COOLDOWN = 4 + + +class MicroBatchDescription: + """ + This is the class to record the infomation of each microbatch, and also do some update operation. + This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more + details, please refer to the doc of these two classes blow. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + """ + + def __init__( + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ) -> None: + self.mb_length = inputs_dict["input_ids"].shape[-1] + self.target_length = self.mb_length + max_output_len + self.infer_state = BatchInferState.init_from_batch( + batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager + ) + # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}") + + def update(self, *args, **kwargs): + pass + + @property + def state(self): + """ + Return the state of current micro batch, when current length is equal to target length, + the state is DONE, otherwise GENERATE + + """ + # TODO: add the condition for early stopping + if self.cur_length == self.target_length: + return Status.DONE + elif self.cur_length == self.target_length - 1: + return Status.COOLDOWN + else: + return Status.GENERATE + + @property + def cur_length(self): + """ + Return the current sequnence length of micro batch + + """ + + +class HeadMicroBatchDescription(MicroBatchDescription): + """ + This class is used to record the infomation of the first stage of pipeline, the first stage should have attributes `input_ids` and `attention_mask` + and `new_tokens`, and the `new_tokens` is the tokens generated by the first stage. Also due to the schdule of pipeline, the operation to update the + information and the condition to determine the state is different from other stages. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + + """ + + def __init__( + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ) -> None: + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) + assert inputs_dict is not None + assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None + self.input_ids = inputs_dict["input_ids"] + self.attn_mask = inputs_dict["attention_mask"] + self.new_tokens = None + + def update(self, new_token: torch.Tensor = None): + if new_token is not None: + self._update_newtokens(new_token) + if self.state is not Status.DONE and new_token is not None: + self._update_attnmask() + + def _update_newtokens(self, new_token: torch.Tensor): + if self.new_tokens is None: + self.new_tokens = new_token + else: + self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1) + + def _update_attnmask(self): + self.attn_mask = torch.cat( + (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1 + ) + + @property + def cur_length(self): + """ + When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token + + """ + if self.new_tokens is None: + return self.mb_length + else: + return self.mb_length + len(self.new_tokens[0]) + + +class BodyMicroBatchDescription(MicroBatchDescription): + """ + This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`, + + Args: + inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. + """ + + def __init__( + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ) -> None: + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) + + @property + def cur_length(self): + """ + When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 + + """ + return self.infer_state.seq_len.max().item() + + +class MicroBatchManager: + """ + MicroBatchManager is a class that manages the micro batch. + + Args: + stage (int): stage id of current stage. + micro_batch_size (int): the micro batch size. + micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. + + """ + + def __init__( + self, + stage: int, + micro_batch_size: int, + micro_batch_buffer_size: int, + max_input_len: int, + max_output_len: int, + cache_manager_list: MemoryManager, + ): + self.stage = stage + self.micro_batch_size = micro_batch_size + self.buffer_size = micro_batch_buffer_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.cache_manager_list = cache_manager_list + self.mb_descrption_buffer = {} + self.new_tokens_buffer = {} + self.idx = 0 + + def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + if self.stage == 0: + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] + ) + else: + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] + ) + + def step(self, new_token: torch.Tensor = None): + """ + Update the state if microbatch manager, 2 conditions. + 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. + 2. For other conditon, only receive the output of previous stage, and update the descrption. + + Args: + inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. + output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. + new_token (torch.Tensor): the new token generated by current stage. + """ + # Add descrption first if the descrption is None + self.cur_descrption.update(new_token) + return self.cur_state + + def export_new_tokens(self): + new_tokens_list = [] + for i in self.mb_descrption_buffer.values(): + new_tokens_list.extend(i.new_tokens.tolist()) + return new_tokens_list + + def is_micro_batch_done(self): + if len(self.mb_descrption_buffer) == 0: + return False + for mb in self.mb_descrption_buffer.values(): + if mb.state != Status.DONE: + return False + return True + + def clear(self): + self.mb_descrption_buffer.clear() + for cache in self.cache_manager_list: + cache.free_all() + + def next(self): + self.idx = (self.idx + 1) % self.buffer_size + + def _remove_descrption(self): + self.mb_descrption_buffer.pop(self.idx) + + @property + def cur_descrption(self) -> MicroBatchDescription: + return self.mb_descrption_buffer.get(self.idx) + + @property + def cur_infer_state(self): + if self.cur_descrption is None: + return None + return self.cur_descrption.infer_state + + @property + def cur_state(self): + """ + Return the state of current micro batch, when current descrption is None, the state is PREFILL + + """ + if self.cur_descrption is None: + return Status.PREFILL + return self.cur_descrption.state diff --git a/colossalai/legacy/inference/quant/gptq/__init__.py b/colossalai/legacy/inference/quant/gptq/__init__.py new file mode 100644 index 000000000000..c035f397923a --- /dev/null +++ b/colossalai/legacy/inference/quant/gptq/__init__.py @@ -0,0 +1,4 @@ +from .cai_gptq import HAS_AUTO_GPTQ + +if HAS_AUTO_GPTQ: + from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear diff --git a/colossalai/legacy/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/legacy/inference/quant/gptq/cai_gptq/__init__.py new file mode 100644 index 000000000000..4ed76293bd81 --- /dev/null +++ b/colossalai/legacy/inference/quant/gptq/cai_gptq/__init__.py @@ -0,0 +1,14 @@ +import warnings + +HAS_AUTO_GPTQ = False +try: + import auto_gptq + + HAS_AUTO_GPTQ = True +except ImportError: + warnings.warn("please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ") + HAS_AUTO_GPTQ = False + +if HAS_AUTO_GPTQ: + from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear + from .gptq_op import CaiGPTQLinearOp diff --git a/colossalai/legacy/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/legacy/inference/quant/gptq/cai_gptq/cai_quant_linear.py new file mode 100644 index 000000000000..36339ac88486 --- /dev/null +++ b/colossalai/legacy/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -0,0 +1,354 @@ +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ + +import math +import warnings +from typing import List, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.lazy import LazyInitContext +from colossalai.shardformer.layer import ParallelModule + +from .gptq_op import CaiGPTQLinearOp + +HAS_GPTQ_CUDA = False +try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True +except ImportError: + warnings.warn("CUDA gptq is not installed") + HAS_GPTQ_CUDA = False + + +class CaiQuantLinear(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else infeatures + + self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer( + "qzeros", + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32), + ) + self.register_buffer( + "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16) + ) + if row_split: + self.register_buffer( + "g_idx", + torch.tensor( + [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32 + ), + ) + else: + self.register_buffer( + "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32) + ) + + if bias: + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + self.gptq_linear = CaiGPTQLinearOp(groupsize, bits) + + self.q4 = None + self.empty_tensor = torch.empty((1, 1), device="meta") + self.tp_size = tp_size + self.tp_rank = tp_rank + self.row_split = row_split + + def pack(self, linear, scales, zeros, g_idx=None): + g_idx = ( + g_idx.clone() + if g_idx is not None + else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + ) + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + half_scales = scales.clone().half() + # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape) + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + pbits = 32 + ptype = torch.int32 + unsign_type = np.uint32 + sign_type = np.int32 + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[ + :, None + ] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(unsign_type) + qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type) + + i = 0 + row = 0 + + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += pbits // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qweight = qweight.astype(sign_type) + qweight1 = torch.from_numpy(qweight) + qweight1 = qweight1.contiguous() # .to("cuda") + self.qweight.data.copy_(qweight1) + + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) + zeros -= 1 + zeros = zeros.numpy().astype(unsign_type) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (pbits // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += pbits // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + qzeros = qzeros.astype(sign_type) + qzeros = torch.from_numpy(qzeros) + qzeros = qzeros + self.qzeros.data.copy_(qzeros) + + if torch.equal(self.g_idx.to(g_idx.device), g_idx): + self.g_idx = None + else: + self.g_idx = g_idx + + def init_q4(self): + assert self.qweight.device.type == "cuda" + self.q4_width = self.qweight.shape[1] + if self.g_idx is not None: + if self.row_split and torch.equal( + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device, + ), + ): + self.g_idx = None + elif torch.equal( + self.g_idx, + torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device + ), + ): + self.g_idx = None + + if self.g_idx is not None: + g_idx = self.g_idx.to("cpu") + else: + g_idx = self.empty_tensor + + self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device()) + torch.cuda.synchronize() + + def forward(self, x): + outshape = x.shape[:-1] + (self.outfeatures,) + + if HAS_GPTQ_CUDA and self.bits == 4: + if self.q4 is None: + self.init_q4() + + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device) + gptq_cuda.q4_matmul(x.half(), self.q4, output) + if self.bias is not None and (not self.row_split or self.tp_size == 1): + output.add_(self.bias) + else: + if self.bias is not None and (not self.row_split or self.tp_size == 1): + bias = self.bias + else: + bias = None + output = self.gptq_linear( + x, + self.qweight, + self.scales, + self.qzeros, + g_idx=self.g_idx, + bias=bias, + ) + return output.view(outshape) + + +def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): + qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) + qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) + scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) + g_idx = gptq_linear.g_idx + if gptq_linear.bias is not None: + bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1) + + cai_split_out_features = cai_linear.outfeatures // split_num + zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num + + for i in range(split_num): + cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] + cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][ + :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block + ] + cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] + if cai_linear.bias is not None: + cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][ + tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] + + cai_linear.g_idx.copy_(g_idx) + + +def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): + qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) + qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) + scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) + g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0) + + cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num + zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num + idx_split_features = cai_linear.infeatures // split_num + + for i in range(split_num): + cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][ + tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, : + ] + cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][ + tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features + ] + if cai_linear.bias is not None: + cai_linear.bias.copy_(gptq_linear.bias) + + +class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) + self.process_group = None + + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = RowCaiQuantLinear( + module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True, + ) + linear_1d.process_group = process_group + + split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d + + def forward(self, x): + output = super().forward(x) + if self.tp_size > 1: + dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group) + if self.bias is not None: + output.add_(self.bias) + return output + + +class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) + self.process_group = None + + @staticmethod + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: + LazyInitContext.materialize(module) + # get the attributes + in_features = module.in_features + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + process_group = process_group[0] + + tp_size = dist.get_world_size(process_group) + tp_rank = dist.get_rank(process_group) + + if in_features < tp_size: + return module + + if in_features % tp_size != 0: + raise ValueError( + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = ColCaiQuantLinear( + module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + ) + linear_1d.process_group = process_group + + split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) + return linear_1d diff --git a/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py new file mode 100644 index 000000000000..a8902eb35cd0 --- /dev/null +++ b/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py @@ -0,0 +1,58 @@ +import torch + +from colossalai.kernel.triton import gptq_fused_linear_triton + + +class CaiGPTQLinearOp(torch.nn.Module): + def __init__(self, gptq_group_size, gptq_quant_bits): + super(CaiGPTQLinearOp, self).__init__() + self.group_size = gptq_group_size + self.bits = gptq_quant_bits + self.maxq = 2**self.bits - 1 + self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device()) + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scales: torch.Tensor, + weight_zeros: torch.Tensor, + g_idx: torch.Tensor = None, + act_type=0, + bias: torch.Tensor = None, + residual: torch.Tensor = None, + qkv_fused=False, + ): + add_bias = True + if bias is None: + bias = self.empty_tensor + add_bias = False + + add_residual = True + if residual is None: + residual = self.empty_tensor + add_residual = False + x = input.view(-1, input.shape[-1]) + + out = gptq_fused_linear_triton( + x, + weight, + weight_scales, + weight_zeros, + bias, + residual, + self.bits, + self.maxq, + self.group_size, + qkv_fused, + add_bias, + add_residual, + act_type=act_type, + g_idx=g_idx, + ) + if qkv_fused: + out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1]) + else: + out = out.view(input.shape[0], input.shape[1], weight.shape[-1]) + + return out diff --git a/colossalai/legacy/inference/quant/smoothquant/__init__.py b/colossalai/legacy/inference/quant/smoothquant/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/legacy/inference/quant/smoothquant/models/__init__.py b/colossalai/legacy/inference/quant/smoothquant/models/__init__.py new file mode 100644 index 000000000000..77541d8610c5 --- /dev/null +++ b/colossalai/legacy/inference/quant/smoothquant/models/__init__.py @@ -0,0 +1,12 @@ +try: + import torch_int + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + raise ImportError( + "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int" + ) + +if HAS_TORCH_INT: + from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP diff --git a/colossalai/legacy/inference/quant/smoothquant/models/base_model.py b/colossalai/legacy/inference/quant/smoothquant/models/base_model.py new file mode 100644 index 000000000000..9554be9ea96b --- /dev/null +++ b/colossalai/legacy/inference/quant/smoothquant/models/base_model.py @@ -0,0 +1,487 @@ +# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py + +import os +import warnings +from abc import abstractmethod +from functools import partial +from os.path import isdir, isfile, join +from typing import Dict, List, Optional, Union + +import accelerate +import numpy as np +import torch +import torch.nn as nn +import transformers +from safetensors.torch import save_file as safe_save +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel +from transformers.modeling_utils import no_init_weights +from transformers.utils.generic import ContextManagers +from transformers.utils.hub import PushToHubMixin, cached_file + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager + +SUPPORTED_MODELS = ["llama"] + + +class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): + layer_type: str = None + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__() + + self.model = model + self.model_type = self.model.config.model_type + self._quantized = quantized + self.config = self.model.config + self.cache_manager = None + self.max_total_token_num = 0 + + @property + def quantized(self): + return self._quantized + + def init_cache_manager(self, max_total_token_num=2048): + if self.config.model_type == "llama": + head_num = self.config.num_key_value_heads + layer_num = self.config.num_hidden_layers + head_dim = self.config.hidden_size // head_num + + self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) + self.max_total_token_num = max_total_token_num + + def init_batch_state(self, max_output_len=256, **kwargs): + input_ids = kwargs["input_ids"] + batch_size = len(input_ids) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + max_len_in_batch = -1 + + for i in range(batch_size): + seq_len = len(input_ids[i]) + seq_lengths[i] = seq_len + seq_start_indexes[i] = start_index + start_index += seq_len + max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch + + if "max_total_token_num" in kwargs.keys(): + max_total_token_num = kwargs["max_total_token_num"] + self.init_cache_manager(max_total_token_num) + + if "max_new_tokens" in kwargs.keys(): + max_output_len = kwargs["max_new_tokens"] + + if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num: + max_total_token_num = batch_size * (max_len_in_batch + max_output_len) + warnings.warn(f"reset max tokens to {max_total_token_num}") + self.init_cache_manager(max_total_token_num) + + block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.cache_manager.free_all() + return batch_infer_state + + @abstractmethod + @torch.inference_mode() + def quantize( + self, + examples: List[Dict[str, Union[List[int], torch.LongTensor]]], + ): + if self.quantized: + raise EnvironmentError("can't execute quantize because the model is quantized.") + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def generate(self, **kwargs): + """shortcut for model.generate""" + + batch_infer_state = self.init_batch_state(**kwargs) + if self.config.model_type == "llama": + setattr(self.model.model, "infer_state", batch_infer_state) + + with torch.inference_mode(): + return self.model.generate(**kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + """shortcut for model.prepare_inputs_for_generation""" + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): + for text in tqdm(dataset): + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + + def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(dataset) + for text in pbar: + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + + # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py + def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) + + self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) + + for h in hooks: + h.remove() + + return act_scales + + # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py + @torch.no_grad() + def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + if hasattr(ln, "bias"): + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + + @classmethod + def create_quantized_model(model): + raise NotImplementedError("Not implement create_quantized_model method") + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + def save_quantized( + self, + save_dir: str, + model_basename: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + ): + """save quantized model and configs to local disk""" + os.makedirs(save_dir, exist_ok=True) + + if not self.quantized: + raise EnvironmentError("can only save quantized model, please execute .quantize first.") + + self.model.to("cpu") + + model_base_name = model_basename # or f"smooth-" + if use_safetensors: + model_save_name = model_base_name + ".safetensors" + state_dict = self.model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + if safetensors_metadata is None: + safetensors_metadata = {} + elif not isinstance(safetensors_metadata, dict): + raise TypeError("safetensors_metadata must be a dictionary.") + else: + print(f"Received safetensors_metadata: {safetensors_metadata}") + new_safetensors_metadata = {} + converted_keys = False + for key, value in safetensors_metadata.items(): + if not isinstance(key, str) or not isinstance(value, str): + converted_keys = True + try: + new_key = str(key) + new_value = str(value) + except Exception as e: + raise TypeError( + f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" + ) + if new_key in new_safetensors_metadata: + print( + f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." + ) + new_safetensors_metadata[new_key] = new_value + safetensors_metadata = new_safetensors_metadata + if converted_keys: + print( + f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}" + ) + + # Format is required to enable Accelerate to load the metadata + # otherwise it raises an OSError + safetensors_metadata["format"] = "pt" + + safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) + else: + model_save_name = model_base_name + ".bin" + torch.save(self.model.state_dict(), join(save_dir, model_save_name)) + + self.model.config.save_pretrained(save_dir) + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + def save_pretrained( + self, + save_dir: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + **kwargs, + ): + """alias of save_quantized""" + warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") + self.save_quantized(save_dir, use_safetensors, safetensors_metadata) + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + max_memory: Optional[dict] = None, + trust_remote_code: bool = False, + torch_dtype: torch.dtype = torch.float16, + **model_init_kwargs, + ): + if not torch.cuda.is_available(): + raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + # Parameters related to loading from Hugging Face Hub + cache_dir = model_init_kwargs.pop("cache_dir", None) + force_download = model_init_kwargs.pop("force_download", False) + resume_download = model_init_kwargs.pop("resume_download", False) + proxies = model_init_kwargs.pop("proxies", None) + local_files_only = model_init_kwargs.pop("local_files_only", False) + use_auth_token = model_init_kwargs.pop("use_auth_token", None) + revision = model_init_kwargs.pop("revision", None) + subfolder = model_init_kwargs.pop("subfolder", "") + model_init_kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + } + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs) + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + # enforce some values despite user specified + model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["trust_remote_code"] = trust_remote_code + if max_memory: + if "disk" in max_memory: + raise NotImplementedError("disk offload not support yet.") + with accelerate.init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.tie_weights() + + max_memory = accelerate.utils.get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + low_zero=False, + ) + model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + ) + model_init_kwargs["low_cpu_mem_usage"] = True + + del model + else: + model_init_kwargs["device_map"] = None + model_init_kwargs["low_cpu_mem_usage"] = False + + torch.cuda.empty_cache() + + merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) + + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + model.eval() + + return cls(model, False) + + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + model_basename: Optional[str] = None, + device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, + max_memory: Optional[dict] = None, + device: Optional[Union[str, int]] = None, + low_cpu_mem_usage: bool = False, + torch_dtype: Optional[torch.dtype] = None, + use_safetensors: bool = False, + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + # == step1: prepare configs and file names == # + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs + ) + + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + extensions = [] + if use_safetensors: + extensions.append(".safetensors") + else: + extensions += [".bin", ".pt"] + + model_name_or_path = str(model_name_or_path) + is_local = isdir(model_name_or_path) + + resolved_archive_file = None + if is_local: + model_save_name = join(model_name_or_path, model_basename) + for ext in extensions: + if isfile(model_save_name + ext): + resolved_archive_file = model_save_name + ext + break + else: # remote + for ext in extensions: + resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs) + if resolved_archive_file is not None: + break + + if resolved_archive_file is None: # Could not find a model file to use + raise FileNotFoundError(f"Could not find model in {model_name_or_path}") + + model_save_name = resolved_archive_file + + # == step2: convert model to quantized-model (replace Linear) == # + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + transformers.modeling_utils._init_weights = False + + init_contexts = [no_init_weights()] + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights(include_buffers=True)) + + with ContextManagers(init_contexts): + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype + ) + cls.create_quantized_model(model) + model.tie_weights() + + # == step3: load checkpoint to quantized-model == # + accelerate.utils.modeling.load_checkpoint_in_model( + model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True + ) + + # == step4: set seqlen == # + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + + return cls( + model, + True, + ) + + def __getattr__(self, item): + try: + return super().__getattr__(item) + except: + return getattr(self.model, item) + + +__all__ = ["BaseSmoothForCausalLM"] diff --git a/colossalai/legacy/inference/quant/smoothquant/models/linear.py b/colossalai/legacy/inference/quant/smoothquant/models/linear.py new file mode 100644 index 000000000000..969c390a0849 --- /dev/null +++ b/colossalai/legacy/inference/quant/smoothquant/models/linear.py @@ -0,0 +1,179 @@ +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py + +import torch +from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32 +from torch_int.functional.quantization import quantize_per_tensor_absmax + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except ImportError: + HAS_SMOOTHQUANT_CUDA = False + raise ImportError("CUDA smoothquant linear is not installed") + + +class W8A8BFP32O32LinearSiLU(torch.nn.Module): + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + if module.bias is not None: + int8_module.bias.data.copy_(module.bias.to(torch.float)) + int8_module.a = alpha + return int8_module + + +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py +class W8A8B8O8Linear(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + self.register_buffer("b", torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale / output_scale + int8_module.weight = int8_weight + int8_module.a = alpha + + if module.bias is not None: + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + int8_module.bias = int8_bias + beta = bias_scale / output_scale + int8_module.b = beta + + return int8_module + + +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py +class W8A8BFP32OFP32Linear(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + + if module.bias is not None: + int8_module.bias = module.bias.to(torch.float32) + + return int8_module diff --git a/colossalai/legacy/inference/quant/smoothquant/models/llama.py b/colossalai/legacy/inference/quant/smoothquant/models/llama.py new file mode 100644 index 000000000000..30063857ac30 --- /dev/null +++ b/colossalai/legacy/inference/quant/smoothquant/models/llama.py @@ -0,0 +1,838 @@ +import math +import os +import types +from collections import defaultdict +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T +from transformers import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LLAMA_INPUTS_DOCSTRING, + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRotaryEmbedding, + repeat_kv, + rotate_half, +) +from transformers.utils import add_start_docstrings_to_model_forward + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import ( + copy_kv_cache_to_dest, + int8_rotary_embedding_fwd, + smooth_llama_context_attn_fwd, + smooth_token_attention_fwd, +) + +from .base_model import BaseSmoothForCausalLM +from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear + + +class LLamaSmoothquantAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + + self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + + self.register_buffer("q_output_scale", torch.tensor([1.0])) + self.register_buffer("k_output_scale", torch.tensor([1.0])) + self.register_buffer("v_output_scale", torch.tensor([1.0])) + self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("out_input_scale", torch.tensor([1.0])) + self.register_buffer("attn_input_scale", torch.tensor([1.0])) + + self._init_rope() + self.num_key_value_heads = num_heads + + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=2048, + base=10000.0, + ) + + @staticmethod + def pack( + module: LlamaAttention, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + ): + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) + + int8_module.attn_input_scale = torch.tensor([attn_input_scale]) + + int8_module.q_output_scale = torch.tensor([q_output_scale]) + int8_module.k_output_scale = torch.tensor([k_output_scale]) + int8_module.v_output_scale = torch.tensor([v_output_scale]) + + int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale]) + int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale]) + + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) + int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) + + int8_module.out_input_scale = torch.tensor([out_input_scale]) + + return int8_module + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + cos = rotary_emb[0] + sin = rotary_emb[1] + + int8_rotary_embedding_fwd( + query_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.q_output_scale.item(), + self.q_rotary_output_scale.item(), + ) + int8_rotary_embedding_fwd( + key_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.k_output_scale.item(), + self.k_rotary_output_scale.item(), + ) + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + + # copy key and value calculated in current step to memory manager + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + smooth_llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.start_loc, + infer_state.seq_len, + q_len, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + smooth_token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.max_len_in_batch, + ) + + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, None, None + + +class LlamaLayerNormQ(torch.nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.input_scale = 1.0 + self.variance_epsilon = eps + self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) + + def forward(self, x): + ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) + ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) + return ln_output_int8 + + @staticmethod + def from_float(module: torch.nn.LayerNorm, output_scale: float): + assert module.weight.shape[0] == module.weight.numel() + q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon) + q_module.weight = module.weight / output_scale + return q_module + + +class LlamaSmoothquantMLP(nn.Module): + def __init__(self, intermediate_size, hidden_size): + super().__init__() + self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) + self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) + self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) + self.register_buffer("down_proj_input_scale", torch.tensor([1.0])) + + @staticmethod + def pack( + mlp_module: LlamaMLP, + gate_proj_input_scale: float, + up_proj_input_scale: float, + down_proj_input_scale: float, + ): + int8_module = LlamaSmoothquantMLP( + mlp_module.intermediate_size, + mlp_module.hidden_size, + ) + + int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) + int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) + int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) + int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale]) + return int8_module + + def forward( + self, + hidden_states: torch.Tensor, + ): + x_shape = hidden_states.shape + gate_out = self.gate_proj(hidden_states) + up_out = self.up_proj(hidden_states) + inter_out = gate_out * up_out + inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8) + down_out = self.down_proj(inter_out) + down_out = down_out.view(*x_shape[:-1], -1) + return down_out + + +class LlamaSmoothquantDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) + + self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) + self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + @staticmethod + def pack( + module: LlamaDecoderLayer, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + up_input_scale: float, + down_input_scale: float, + ): + config = module.self_attn.config + int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) + + int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) + int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( + module.self_attn, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + q_rotary_output_scale, + k_rotary_output_scale, + out_input_scale, + ) + + int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float( + module.post_attention_layernorm, gate_input_scale + ) + + int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( + module.mlp, + gate_input_scale, + up_input_scale, + down_input_scale, + ) + + return int8_decoder_layer + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None, None + + +class LlamaApplyRotary(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (rotate_half(x) * sin) + + return x_embed + + +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = self.q_apply_rotary(query_states, cos, sin, position_ids) + key_states = self.k_apply_rotary(key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def init_to_get_rotary(config, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + base : calculation arg + use_elem : activated when using chatglm-based models + """ + config.head_dim_ = config.hidden_size // config.num_attention_heads + if not hasattr(config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0 + + if hasattr(config, "max_sequence_length"): + max_seq_len = config.max_sequence_length + elif hasattr(config, "max_position_embeddings"): + max_seq_len = config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula + except: + pass + + n_elem = config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + _cos_cached = torch.cos(freqs).to(torch.float) + _sin_cached = torch.sin(freqs).to(torch.float) + return _cos_cached, _sin_cached + + +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + infer_state = self.infer_state + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 + + seq_length_with_past = seq_length + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if infer_state.is_context_stage: + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}") + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + raise NotImplementedError("not implement gradient_checkpointing and training options ") + + if past_key_values_length == 0: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + infer_state.decode_layer_id = 0 + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + rotary_emb=(position_cos, position_sin), + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + + hidden_states = layer_outputs[0] + infer_state.decode_layer_id += 1 + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + infer_state.is_context_stage = False + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): + layer_type = "LlamaDecoderLayer" + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__(model, quantized) + + # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py + def get_act_dict( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + ): + llama_model = self.model + + llama_model.eval() + device = next(llama_model.parameters()).device + # print("model:", llama_model) + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaAttention): + setattr(m, "q_apply_rotary", LlamaApplyRotary()) + setattr(m, "k_apply_rotary", LlamaApplyRotary()) + m.forward = types.MethodType(llama_decoder_layer_forward, m) + + hooks = [] + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaApplyRotary): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + + self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) + + for hook in hooks: + hook.remove() + return act_dict + + def smooth_fn(self, scales, alpha=0.5): + model = self.model + for name, module in model.named_modules(): + if isinstance(module, LlamaDecoderLayer): + attn_ln = module.input_layernorm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + ".self_attn.q_proj"] + self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + def create_quantized_model(model): + llama_config = model.config + for i, layer in enumerate(model.model.layers): + model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) + + model.model.forward = types.MethodType(llama_model_forward, model.model) + cos, sin = init_to_get_rotary(llama_config) + model.model.register_buffer("_cos_cached", cos) + model.model.register_buffer("_sin_cached", sin) + + def quantized( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + alpha=0.5, + ): + llama_model = self.model + llama_config = llama_model.config + + act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) + + self.smooth_fn(act_scales, alpha) + + act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) + decoder_layer_scales = [] + + for idx in range(llama_config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 + scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 + scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 + scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + + scale_dict["q_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 + ) + scale_dict["k_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 + ) + + scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 + + scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 + scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + + decoder_layer_scales.append(scale_dict) + + for i, layer in enumerate(llama_model.model.layers): + orig_layer = layer + llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + + llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + + cos, sin = init_to_get_rotary(llama_config) + llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device)) + llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device)) diff --git a/colossalai/legacy/inference/tensor_parallel/__init__.py b/colossalai/legacy/inference/tensor_parallel/__init__.py new file mode 100644 index 000000000000..112b920ba158 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/__init__.py @@ -0,0 +1,4 @@ +from .engine import TPInferEngine +from .kvcache_manager import MemoryManager + +__all__ = ["MemoryManager", "TPInferEngine"] diff --git a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py new file mode 100644 index 000000000000..f707a86df37e --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py @@ -0,0 +1,118 @@ +# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later +from dataclasses import dataclass + +import torch +from transformers.tokenization_utils_base import BatchEncoding + +from .kvcache_manager import MemoryManager + + +# adapted from: lightllm/server/router/model_infer/infer_batch.py +@dataclass +class BatchInferState: + r""" + Information to be passed and used for a batch of inputs during + a single model forward + """ + batch_size: int + max_len_in_batch: int + + cache_manager: MemoryManager = None + + block_loc: torch.Tensor = None + start_loc: torch.Tensor = None + seq_len: torch.Tensor = None + past_key_values_len: int = None + + is_context_stage: bool = False + context_mem_index: torch.Tensor = None + decode_is_contiguous: bool = None + decode_mem_start: int = None + decode_mem_end: int = None + decode_mem_index: torch.Tensor = None + decode_layer_id: int = None + + device: torch.device = torch.device("cuda") + + @property + def total_token_num(self): + # return self.batch_size * self.max_len_in_batch + assert self.seq_len is not None and self.seq_len.size(0) > 0 + return int(torch.sum(self.seq_len)) + + def set_cache_manager(self, manager: MemoryManager): + self.cache_manager = manager + + # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1 + @staticmethod + def init_block_loc( + b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor + ): + """in-place update block loc mapping based on the sequence length of the inputs in current bath""" + start_index = 0 + seq_len_numpy = seq_len.cpu().numpy() + for i, cur_seq_len in enumerate(seq_len_numpy): + b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[ + start_index : start_index + cur_seq_len + ] + start_index += cur_seq_len + return + + @classmethod + def init_from_batch( + cls, + batch: torch.Tensor, + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ): + if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") + + input_ids_list = None + attention_mask = None + + if isinstance(batch, (BatchEncoding, dict)): + input_ids_list = batch["input_ids"] + attention_mask = batch["attention_mask"] + else: + input_ids_list = batch + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(batch, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attention_mask): + curr_seq_len = len(attn_mask) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + length = max(len(input_id) for input_id in input_ids_list) + for i, input_ids in enumerate(input_ids_list): + curr_seq_len = length + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") + + return cls( + batch_size=batch_size, + max_len_in_batch=max_len_in_batch, + seq_len=seq_lengths.to("cuda"), + start_loc=seq_start_indexes.to("cuda"), + block_loc=block_loc, + decode_layer_id=0, + past_key_values_len=0, + is_context_stage=True, + cache_manager=cache_manager, + ) diff --git a/colossalai/legacy/inference/tensor_parallel/engine.py b/colossalai/legacy/inference/tensor_parallel/engine.py new file mode 100644 index 000000000000..2478b574d307 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/engine.py @@ -0,0 +1,480 @@ +from typing import Any, Callable, List, Optional, Union + +import torch +import torch.nn as nn +from transformers import BloomForCausalLM, LlamaForCausalLM +from transformers.generation import GenerationConfig +from transformers.generation.stopping_criteria import StoppingCriteriaList +from transformers.tokenization_utils_base import BatchEncoding + +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.auto_policy import get_autopolicy + +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager + +# from dynamic_batching.infer_batch import InferBatch + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + +_supported_models = [ + "LlamaForCausalLM", + "LlamaModel", + "BloomForCausalLM", + "ChatGLMModel", + "ChatGLMForConditionalGeneration", + "LlamaGPTQForCausalLM", + "BloomGPTQForCausalLM", +] + + +class TPInferEngine: + """Engine class for tensor parallel inference. + + Args: + model (Module): original model, e.g. huggingface CausalLM + shard_config (ShardConfig): The config for sharding original model + max_batch_size (int): maximum batch size + max_input_len (int): maximum input length of sequence + max_output_len (int): maximum output length of output tokens + dtype (torch.dtype): datatype used to init KV cache space + device (str): device the KV cache of engine to be initialized on + + Examples: + >>> # define model and shard config for your inference + >>> model = ... + >>> generate_kwargs = ... + >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + >>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + >>> outputs = infer_engine.generate(input_ids, **generate_kwargs) + """ + + def __init__( + self, + model: nn.Module, + shard_config: ShardConfig, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + ) -> None: + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) + # Constraints relatable with specs of devices and model + # This may change into an optional arg in the future + assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" + assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint" + + self.dtype = dtype + + self.head_dim = model.config.hidden_size // model.config.num_attention_heads + self.head_num = model.config.num_attention_heads + num_hidden_layers = ( + model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers + ) + self.layer_num = num_hidden_layers + + self.multi_query_group_num = model.config.num_attention_heads + # default to attention_heads + if hasattr(model.config, "multi_query_attention"): + self.multi_query_attention = getattr(model.config, "multi_query_attention") + + if hasattr(model.config, "multi_query_group_num"): + self.multi_query_group_num = getattr(model.config, "multi_query_group_num") + + if hasattr(model.config, "num_key_value_heads"): + self.multi_query_group_num = getattr(model.config, "num_key_value_heads") + + self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config + self.cache_manager = None + + self.max_dq_buffer_size = 1 + self.max_inner_outer_dim = 1 + self.gptq_temp_state_buffer = None + self.gptq_temp_dq_buffer = None + self.bits = -1 + self.use_act_order = False + + self.shard_config = shard_config + self.model = None + self.cache = {} + + # optimize the original model by sharding with ShardFormer + self._optimize_model(model=model.to(device)) + + def _init_manager(self) -> None: + assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" + assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" + self.head_num //= self.tp_size # update sharded number of heads + + if hasattr(self, "multi_query_attention"): + # NOTE the logic of MQA tensor parallelism should be specified. + assert ( + self.multi_query_group_num % self.tp_size == 0 + ), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}" + self.cache_manager = MemoryManager( + self.max_total_token_num, + self.dtype, + self.multi_query_group_num // self.tp_size, + self.head_dim, + self.layer_num, + ) + else: + self.cache_manager = MemoryManager( + self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num + ) + + def _post_init_gptq_buffer(self, model: nn.Module) -> None: + from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear + + HAS_GPTQ_CUDA = False + try: + from colossalai.kernel.op_builder.gptq import GPTQBuilder + + gptq_cuda = GPTQBuilder().load() + HAS_GPTQ_CUDA = True + except ImportError: + warnings.warn("CUDA gptq is not installed") + HAS_GPTQ_CUDA = False + + for name, submodule in model.named_modules(): + if isinstance(submodule, CaiQuantLinear): + self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) + + if self.use_act_order: + self.max_inner_outer_dim = max( + self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures + ) + self.bits = submodule.bits + if not (HAS_GPTQ_CUDA and self.bits == 4): + return + + max_input_len = 1 + if self.use_act_order: + max_input_len = self.max_input_len + # The temp_state buffer is required to reorder X in the act-order case. + # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + self.gptq_temp_state_buffer = torch.zeros( + (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() + ) + self.gptq_temp_dq_buffer = torch.zeros( + (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device() + ) + + gptq_cuda.prepare_buffers( + torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer + ) + # Using the default from exllama repo here. + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + torch.cuda.empty_cache() + + def _optimize_model(self, model: nn.Module) -> None: + """ + Optimize the original model by sharding with ShardFormer. + In further generation, use the sharded model instead of original model. + """ + # NOTE we will change to use an inference config later with additional attrs we want + assert self.shard_config.inference_only is True + shardformer = ShardFormer(shard_config=self.shard_config) + self._prepare_with_shard_config(shard_config=self.shard_config) + self._shard_model_by(shardformer, model) + + def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: + """Prepare the engine with a given ShardConfig. + + Args: + shard_config (ShardConfig): shard config given to specify settings of the engine. + If not provided, a default ShardConfig with tp size 1 will be created. + """ + self.tp_size = 1 + if shard_config is None: + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + inference_only=True, + ) + else: + shard_config.inference_only = True + shard_config.pipeline_stage_manager = None + if shard_config.enable_tensor_parallelism: + self.tp_size = shard_config.tensor_parallel_size + self._init_manager() + + return shard_config + + def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: + """Shard original model by the given ShardFormer and store the sharded model.""" + assert ( + self.tp_size == shardformer.shard_config.tensor_parallel_size + ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" + model_name = model.__class__.__name__ + assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." + + model = model.model if self.shard_config.inference_gptq else model + policy = get_autopolicy(model, shard_config=self.shard_config) + + self.model, _ = shardformer.optimize(model, policy) + + if self.shard_config.inference_gptq: + self._post_init_gptq_buffer(self.model) + + self.model = self.model.cuda() + + @property + def supported_models(self) -> List[str]: + return _supported_models + + def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor: + """Generate token sequence. + + Args: + input_tokens: could be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + Returns: + torch.Tensor: The returned sequence is given inputs + generated_tokens. + """ + if isinstance(input_tokens, torch.Tensor): + input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool)) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].cuda() + if "max_new_tokens" not in generate_kwargs: + generate_kwargs.update(max_new_tokens=self.max_output_len) + + return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) + + def prepare_batch_state(self, inputs) -> BatchInferState: + """ + Create and prepare BatchInferState used for inference during model forwrad, + by processing each sequence of the given inputs. + + Args: + inputs: should be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve + the actual length (e.g. number of tokens) of each input without attention mask + Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume + all the inputs in the batch has the maximum length l + Returns: + BatchInferState: the states for the current batch during inference + """ + if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") + + input_ids_list = None + attention_mask = None + + if isinstance(inputs, (BatchEncoding, dict)): + input_ids_list = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + else: + input_ids_list = inputs + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(inputs, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attention_mask): + curr_seq_len = len(attn_mask) + # if isinstance(attn_mask, torch.Tensor): + # curr_seq_len = int(torch.sum(attn_mask)) + # else: + # curr_seq_len = int(sum(attn_mask)) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + length = max(len(input_id) for input_id in input_ids_list) + for i, input_ids in enumerate(input_ids_list): + curr_seq_len = length + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + + return batch_infer_state + + @torch.no_grad() + def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor: + """ + Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate + + Args: + inputs: should be one of the following types + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + """ + + # for testing, always use sharded model + assert self.model is not None, "sharded model does not exist" + + batch_infer_state = self.prepare_batch_state(input_tokens) + assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" + + # set BatchInferState for the current batch as attr to model + # NOTE this is not a preferable way to pass BatchInferState during inference + # we might want to rewrite generate function (e.g. _generate_by_pass_infer_state) + # and pass BatchInferState via model forward + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + setattr(model, "infer_state", batch_infer_state) + + outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) + + # NOTE In future development, we're going to let the scheduler to handle the cache, + # instead of freeing space explicitly at the end of generation + self.cache_manager.free_all() + + return outputs + + # TODO might want to implement the func that generates output tokens by passing BatchInferState + # as an arg into model.forward. + # It requires rewriting model generate and replacing model forward. + @torch.no_grad() + def _generate_by_pass_infer_state( + self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs, + ) -> torch.Tensor: + raise NotImplementedError("generate by passing BatchInferState is not implemented.") + + # might want to use in rewritten generate method: use after model.forward + # BatchInferState is created and kept during generation + # after each iter of model forward, we should update BatchInferState + def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: + batch_size = infer_state.batch_size + device = infer_state.start_loc.device + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) + infer_state.seq_len += 1 + + @torch.no_grad() + def forward(self, batch_id, is_prefill): + """ + Forward is used in Dynamic Batching Manager + """ + batch = self.cache.pop(batch_id) + if is_prefill: + input_ = torch.tensor(batch.all_input_ids).cuda() + else: + input_ = batch.input_ids.reshape(len(batch), 1) + + batch_args = { + "batch_size": len(batch), + "max_len_in_batch": batch.nopad_max_len_in_batch, + "block_loc": batch.nopad_b_loc, + "start_loc": batch.nopad_b_start_loc, + "seq_len": batch.nopad_b_seq_len, + "cache_manager": batch.cache_manager, + "is_context_stage": is_prefill, + } + + infer_state = BatchInferState(**batch_args) + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + + setattr(model, "infer_state", infer_state) + output = self.model.forward(input_ids=input_) + logits = output.logits + # bsz, seq_len, vocab_size + prob_out = torch.softmax( + logits[ + :, + -1, + ], + dim=-1, + ).squeeze(1) + # prob_out: bsz, vocab_size + predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) + prob_out = torch.log(prob_out).detach().cpu().numpy() + predict_ids = predict_ids.detach().cpu().numpy() + # [ batch_size, 1 ] + + output_dict = {} + new_input_ids = [] + for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( + zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) + ): + next_token_id = int(next_token_id) + next_token_logprob = next_token_logprob[next_token_id] + # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") + all_input_ids.append(next_token_id) + # all_input_ids_tensor = None + new_input_ids.append(next_token_id) + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] += 1 + batch.out_token_id_counts[i][next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + output_dict[r["request_id"]] = (int(next_token_id), metadata) + + batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() + batch.nopad_total_token_num += len(batch) + batch.nopad_max_len_in_batch += 1 # NOTE: we may repalce this + self.cache[batch.batch_id] = batch + return output_dict + + @torch.no_grad() + def _prefill_batch(self, batch_id): + return self.forward(batch_id, is_prefill=True) + + @torch.no_grad() + def _decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + + # might want to create a sequence pool + # add a single request/sequence/input text at a time and record its length + # In other words, store the actual length of input tokens representing a single input text + # E.g. "Introduce landmarks in Beijing" + # => add request + # => record token length and other necessary information to be used + # => engine hold all these necessary information until `generate` (or other name) is called, + # => put information already recorded in batchinferstate and pass it to model forward + # => clear records in engine + def add_request(): + raise NotImplementedError() diff --git a/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py new file mode 100644 index 000000000000..91bb96a1f1f0 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py @@ -0,0 +1,106 @@ +""" +Refered/Modified from lightllm/common/mem_manager.py +of the ModelTC/lightllm GitHub repository +https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py +we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. +""" +import torch +from transformers.utils import logging + + +class MemoryManager: + r""" + Manage token block indexes and allocate physical memory for key and value cache + + Args: + size: maximum token number used as the size of key and value buffer + dtype: data type of cached key and value + head_num: number of heads the memory manager is responsible for + head_dim: embedded size per head + layer_num: the number of layers in the model + device: device used to store the key and value cache + """ + + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: torch.device = torch.device("cuda"), + ): + self.logger = logging.get_logger(__name__) + self.available_size = size + self.max_len_in_batch = 0 + self._init_mem_states(size, device) + self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) + + def _init_mem_states(self, size, device): + """Initialize tensors used to manage memory states""" + self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) + self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) + self.indexes = torch.arange(0, size, dtype=torch.long, device=device) + + def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): + """Initialize key buffer and value buffer on specified device""" + self.key_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + self.value_buffer = [ + torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) + ] + + @torch.no_grad() + def alloc(self, required_size): + """allocate space of required_size by providing indexes representing available physical spaces""" + if required_size > self.available_size: + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") + return None + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) + select_index = self.indexes[select_index] + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + return select_index + + @torch.no_grad() + def alloc_contiguous(self, required_size): + """allocate contiguous space of required_size""" + if required_size > self.available_size: + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") + return None + torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) + sum_size = len(self.mem_cum_sum) + loc_sums = ( + self.mem_cum_sum[required_size - 1 :] + - self.mem_cum_sum[0 : sum_size - required_size + 1] + + self.mem_state[0 : sum_size - required_size + 1] + ) + can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size] + if can_used_loc.shape[0] == 0: + self.logger.info( + f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}" + ) + return None + start_loc = can_used_loc[0] + select_index = self.indexes[start_loc : start_loc + required_size] + self.mem_state[select_index] = 0 + self.available_size -= len(select_index) + start = start_loc.item() + end = start + required_size + return select_index, start, end + + @torch.no_grad() + def free(self, free_index): + """free memory by updating memory states based on given indexes""" + self.available_size += free_index.shape[0] + self.mem_state[free_index] = 1 + + @torch.no_grad() + def free_all(self): + """free all memory by updating memory states""" + self.available_size = len(self.mem_state) + self.mem_state[:] = 1 + self.max_len_in_batch = 0 + self.logger.info("freed all space of memory manager") diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/__init__.py b/colossalai/legacy/inference/tensor_parallel/modeling/__init__.py new file mode 100644 index 000000000000..4662368b17b4 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/modeling/__init__.py @@ -0,0 +1,5 @@ +from .bloom import BloomInferenceForwards +from .chatglm2 import ChatGLM2InferenceForwards +from .llama import LlamaInferenceForwards + +__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"] diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py new file mode 100644 index 000000000000..068b64b4f829 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py @@ -0,0 +1,67 @@ +""" +Utils for model inference +""" +import os + +import torch + +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + """ + This function copies the key and value cache to the memory cache + Args: + layer_id : id of current layer + key_buffer : key cache + value_buffer : value cache + context_mem_index : index of memory cache in kv cache manager + mem_manager : cache manager + """ + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/bloom.py b/colossalai/legacy/inference/tensor_parallel/modeling/bloom.py new file mode 100644 index 000000000000..74fa5f470bf8 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/modeling/bloom.py @@ -0,0 +1,540 @@ +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F +from transformers.models.bloom.modeling_bloom import ( + BaseModelOutputWithPastAndCrossAttentions, + BloomAttention, + BloomBlock, + BloomForCausalLM, + BloomModel, + CausalLMOutputWithCrossAttentions, +) +from transformers.utils import logging + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd + +try: + from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_bloom_context_attention_fwd, + ) + + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + + +def generate_alibi(n_head, dtype=torch.float16): + """ + This method is adapted from `_generate_alibi` function + in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py` + of the ModelTC/lightllm GitHub repository. + This method is originally the `build_alibi_tensor` function + in `transformers/models/bloom/modeling_bloom.py` + of the huggingface/transformers GitHub repository. + """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + return [start * start**i for i in range(n)] + + def get_slopes(n): + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) + slopes_double = get_slopes(2 * closest_power_of_2) + slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2] + return slopes_combined + + slopes = get_slopes(n_head) + return torch.tensor(slopes, dtype=dtype) + + +class BloomInferenceForwards: + """ + This class serves a micro library for bloom inference forwards. + We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention, + as well as prepare_inputs_for_generation method for BloomForCausalLM. + For future improvement, we might want to skip replacing methods for BloomForCausalLM, + and call BloomModel.forward iteratively in TpInferEngine + """ + + @staticmethod + def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # still need to keep past_key_values to fit original forward flow + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # NOTE determine if BatchInferState is passed in via arg + # if not, get the attr binded to the model + # We might wantto remove setattr later + if infer_state is None: + assert hasattr(self, "infer_state") + infer_state = self.infer_state + + # infer_state.cache_manager = self.cache_manager + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 + + if use_cache and seq_length != 1: + # prefill stage + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + BatchInferState.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + + if attention_mask is None: + attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + # NOTE revise: we might want to store a single 1D alibi(length is #heads) in model, + # or store to BatchInferState to prevent re-calculating + # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here + # alibi = generate_alibi(self.num_heads).contiguous().cuda() + tp_size = dist.get_world_size() + curr_tp_rank = dist.get_rank() + alibi = ( + generate_alibi(self.num_heads * tp_size) + .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads] + .cuda() + ) + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + infer_state.decode_layer_id = 0 + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + # NOTE: currently our KV cache manager does not handle this condition + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + infer_state=infer_state, + ) + + infer_state.decode_layer_id += 1 + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # update indices of kv cache block + # NOT READY FOR PRIME TIME + # might want to remove this part, instead, better to pass the BatchInferState from model forward, + # and update these information in engine.generate after model foward called + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, # should always be (None, None, ..., None) + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + @staticmethod + def bloom_for_causal_lm_forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = BloomInferenceForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def bloom_for_causal_lm_prepare_inputs_for_generation( + self: BloomForCausalLM, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # NOTE we won't use past key values here + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + # if past_key_values[0][0].shape[0] == input_ids.shape[0]: + # past_key_values = self._convert_to_bloom_cache(past_key_values) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def bloom_block_forward( + self: BloomBlock, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attn_outputs = self.self_attention( + layernorm_output, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + infer_state=infer_state, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. + output = self.mlp(layernorm_output, residual) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + @staticmethod + def bloom_attention_forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, q_length, H, D_HEAD = query_layer.shape + k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + + mem_manager = infer_state.cache_manager + layer_id = infer_state.decode_layer_id + + if infer_state.is_context_stage: + # context process + max_input_len = q_length + b_start_loc = infer_state.start_loc + b_seq_len = infer_state.seq_len[:batch_size] + q = query_layer.reshape(-1, H, D_HEAD) + + copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) + + # output = self.output[:batch_size*q_length, :, :] + output = torch.empty_like(q) + + if HAS_LIGHTLLM_KERNEL: + lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) + else: + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + else: + # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD) + assert q_length == 1, "for non-context process, we only support q_length == 1" + q = query_layer.reshape(-1, H, D_HEAD) + + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(k) + cache_v.copy_(v) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] + copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) + + b_start_loc = infer_state.start_loc + b_loc = infer_state.block_loc + b_seq_len = infer_state.seq_len + output = torch.empty_like(q) + token_attention_fwd( + q, + mem_manager.key_buffer[layer_id], + mem_manager.value_buffer[layer_id], + output, + b_loc, + b_start_loc, + b_seq_len, + infer_state.max_len_in_batch, + alibi, + ) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + + # NOTE: always set present as none for now, instead of returning past key value to the next decoding, + # we create the past key value pair from the cache manager + present = None + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # dropout is not required here during inference + output_tensor = residual + output_tensor + + outputs = (output_tensor, present) + assert output_attentions is False, "we do not support output_attentions at this time" + + return outputs diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/legacy/inference/tensor_parallel/modeling/chatglm2.py new file mode 100644 index 000000000000..b8fe8eb54855 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/modeling/chatglm2.py @@ -0,0 +1,545 @@ +import os +from typing import Optional, Tuple + +import torch +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, + GLMTransformer, + SelfAttention, + split_tensor_along_last_dim, +) + +from ._utils import copy_kv_to_mem_cache + +try: + from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + + +# This func is same as Llama model init_to_get_rotary, we should move them into _utils.py +def _init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + except: + pass + n_elem = self.config.head_dim_ // 2 + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def get_masks(self, input_ids, past_length, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + +class ChatGLM2InferenceForwards: + """ + This class holds forwards for Chatglm2 inference. + We intend to replace the forward methods for ChatGLMModel, ChatGLMEecoderLayer, and ChatGLMAttention. + """ + + @staticmethod + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + infer_state = self.infer_state + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 + + seq_length_with_past = seq_length + past_key_values_length + + # prefill stage at first + if use_cache and seq_length != 1: + infer_state.is_context_stage = True + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + # related to rotary embedding + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def chatglm_model_forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = get_masks( + self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask + ) + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + infer_state=infer_state, + ) + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + @staticmethod + def chatglm_encoder_forward( + self: GLMTransformer, + hidden_states, + attention_mask, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ): + hidden_states = hidden_states.transpose(0, 1).contiguous() + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + + infer_state.decode_layer_id = 0 + for index in range(self.num_layers): + layer = self.layers[index] + + layer_ret = layer( + hidden_states, + attention_mask, + kv_cache=kv_caches[index], + use_cache=use_cache, + infer_state=infer_state, + ) + + infer_state.decode_layer_id += 1 + + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + hidden_states = hidden_states.transpose(0, 1).contiguous() + + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + @staticmethod + def chatglm_glmblock_forward( + self: GLMBlock, + hidden_states, + attention_mask, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + kv_cache=kv_cache, + use_cache=use_cache, + infer_state=infer_state, + ) + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + return output, kv_cache + + @staticmethod + def chatglm_flash_attn_kvcache_forward( + self: SelfAttention, + hidden_states, + attention_mask, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + assert use_cache is True, "use_cache should be set to True using this chatglm attention" + # hidden_states: original :[sq, b, h] --> this [b, sq, h] + batch_size = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + cos, sin = infer_state.position_cos, infer_state.position_sin + + chatglm2_rotary_emb_fwd( + query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin + ) + if self.multi_query_attention: + chatglm2_rotary_emb_fwd( + key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) + else: + chatglm2_rotary_emb_fwd( + key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) + + # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128 + query_layer = query_layer.reshape( + -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ) + key_layer = key_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) + value_layer = value_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) + + if infer_state.is_context_stage: + # first token generation: + # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) + + # NOTE: no bug in context attn fwd (del it ) + lightllm_llama2_context_attention_fwd( + query_layer, + key_layer, + value_layer, + attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + infer_state.start_loc, + infer_state.seq_len, + infer_state.max_len_in_batch, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_layer) + cache_v.copy_(value_layer) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # second token and follows + attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] + + # ================================== + # core attention computation is replaced by triton kernel + # ================================== + Llama2TokenAttentionForwards.token_attn( + query_layer, + cache_k, + cache_v, + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + + # print('after attention',torch.isnan(attn_output).any()) + + # ================= + # Output:[b,sq, h] + # ================= + output = self.dense(attn_output).reshape(batch_size, -1, hidden_size) + + return output, kv_cache diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/llama.py b/colossalai/legacy/inference/tensor_parallel/modeling/llama.py new file mode 100644 index 000000000000..62c2aad3c055 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/modeling/llama.py @@ -0,0 +1,423 @@ +import math +from typing import List, Optional, Tuple + +import torch +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards + +from ._utils import copy_kv_to_mem_cache + +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + +try: + from flash_attn import flash_attn_with_kvcache + + HAS_FLASH_KERNEL = True +except: + HAS_FLASH_KERNEL = False + print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def llama_triton_context_attention( + query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 +): + if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + lightllm_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + + +def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" + if num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + + +class LlamaInferenceForwards: + """ + This class holds forwards for llama inference. + We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. + """ + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + infer_state = self.infer_state + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + use_cache = use_cache if use_cache is not None else self.config.use_cache + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assume prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.repeat(batch_size, 1) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + infer_state.decode_layer_id = 0 + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] if past_key_values is not None else None + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + infer_state.decode_layer_id += 1 + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + next_cache = next_decoder_cache if use_cache else None + + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + @staticmethod + def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # NOTE might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + + cos, sin = infer_state.position_cos, infer_state.position_sin + + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + attn_output = torch.empty_like(query_states) + + llama_triton_context_attention( + query_states, + key_states, + value_states, + attn_output, + infer_state, + num_key_value_groups=self.num_key_value_groups, + ) + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + if HAS_LIGHTLLM_KERNEL: + attn_output = torch.empty_like(query_states) + llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups + ) + else: + self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache( + q=query_states, + k_cache=copy_cache_k, + v_cache=copy_cache_v, + softmax_scale=1 / math.sqrt(self.head_dim), + causal=True, + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None diff --git a/colossalai/legacy/inference/tensor_parallel/policies/__init__.py b/colossalai/legacy/inference/tensor_parallel/policies/__init__.py new file mode 100644 index 000000000000..776c4e850565 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/policies/__init__.py @@ -0,0 +1,5 @@ +from .bloom import BloomModelInferPolicy +from .chatglm2 import ChatGLM2InferPolicy +from .llama import LlamaModelInferPolicy + +__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy", "ChatGLM2InferPolicy"] diff --git a/colossalai/legacy/inference/tensor_parallel/policies/bloom.py b/colossalai/legacy/inference/tensor_parallel/policies/bloom.py new file mode 100644 index 000000000000..fba83a08175d --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/policies/bloom.py @@ -0,0 +1,101 @@ +from functools import partial + +import torch +from torch.nn import LayerNorm + +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription +from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy + +from ..modeling.bloom import BloomInferenceForwards + +try: + from colossalai.kernel.triton import layer_norm + + HAS_TRITON_NORM = True +except: + print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton") + HAS_TRITON_NORM = False + + +def get_triton_layernorm_forward(): + if HAS_TRITON_NORM: + + def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor): + return layer_norm(hidden_states, self.weight.data, self.bias, self.eps) + + return _triton_layernorm_forward + else: + return None + + +class BloomModelInferPolicy(BloomForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel + + policy = super().module_policy() + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + policy[BloomBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 3}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + ], + ) + # NOTE set inference mode to shard config + self.shard_config._infer() + + method_replacement = { + "forward": BloomInferenceForwards.bloom_for_causal_lm_forward, + "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation, + } + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomForCausalLM + ) + + method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel) + + method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock) + + method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomAttention + ) + + if HAS_TRITON_NORM: + infer_method = get_triton_layernorm_forward() + method_replacement = {"forward": partial(infer_method)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LayerNorm + ) + + return policy diff --git a/colossalai/legacy/inference/tensor_parallel/policies/chatglm2.py b/colossalai/legacy/inference/tensor_parallel/policies/chatglm2.py new file mode 100644 index 000000000000..60dc511f5e96 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/policies/chatglm2.py @@ -0,0 +1,77 @@ +from functools import partial + +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMBlock, + GLMTransformer, + SelfAttention, +) + +# import colossalai +from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.chatglm2 import ChatGLM2InferenceForwards + +try: + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +class ChatGLM2InferPolicy(ChatGLMModelPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + self.shard_config._infer() + + model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward + method_replacement = {"forward": model_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel) + + encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward + method_replacement = {"forward": encoder_infer_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=GLMTransformer + ) + + encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward + method_replacement = {"forward": encoder_layer_infer_forward} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock) + + attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward + method_replacement = {"forward": attn_infer_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=SelfAttention + ) + if self.shard_config.enable_tensor_parallelism: + policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = ( + self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size + ) + # for rmsnorm and others, we need to check the shape + return policy + + def postprocess(self): + init_to_get_rotary(self.model) + return self.model + + +class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward + method_replacement = {"forward": partial(model_infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=ChatGLMForConditionalGeneration + ) + return policy + + def postprocess(self): + return super().postprocess() diff --git a/colossalai/legacy/inference/tensor_parallel/policies/llama.py b/colossalai/legacy/inference/tensor_parallel/policies/llama.py new file mode 100644 index 000000000000..3acba22cd164 --- /dev/null +++ b/colossalai/legacy/inference/tensor_parallel/policies/llama.py @@ -0,0 +1,121 @@ +from functools import partial + +import torch +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm + +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + +# import colossalai +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.llama import LlamaInferenceForwards + +try: + from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + ], + ) + + self.shard_config._infer() + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index a7fc3d29b77a..4ee1a5fb1234 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -4,11 +4,20 @@ try: from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") + +import importlib.util + +HAS_LIGHTLLM_KERNEL = True + +if importlib.util.find_spec("lightllm") is None: + HAS_LIGHTLLM_KERNEL = False + TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6") @@ -25,7 +34,8 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): @pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL, + reason="triton requires cuda version to be higher than 11.4 or not install lightllm", ) def test(): Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 From a6bff989fd20ac604c758baa5e3330337215eea9 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 10 Nov 2023 13:55:12 +0800 Subject: [PATCH 5/5] fix chatglm2 --- colossalai/inference/hybridengine/modeling/__init__.py | 3 ++- colossalai/inference/hybridengine/modeling/chatglm2.py | 2 +- colossalai/inference/hybridengine/polices/__init__.py | 2 +- .../inference/hybridengine/polices/{chatglm.py => chatglm2.py} | 0 4 files changed, 4 insertions(+), 3 deletions(-) rename colossalai/inference/hybridengine/polices/{chatglm.py => chatglm2.py} (100%) diff --git a/colossalai/inference/hybridengine/modeling/__init__.py b/colossalai/inference/hybridengine/modeling/__init__.py index a6603066ad51..8a9e9999d3c5 100644 --- a/colossalai/inference/hybridengine/modeling/__init__.py +++ b/colossalai/inference/hybridengine/modeling/__init__.py @@ -1,4 +1,5 @@ from .bloom import BloomInferenceForwards +from .chatglm2 import ChatGLM2InferenceForwards from .llama import LlamaInferenceForwards -__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards"] +__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards", "ChatGLM2InferenceForwards"] diff --git a/colossalai/inference/hybridengine/modeling/chatglm2.py b/colossalai/inference/hybridengine/modeling/chatglm2.py index 0110b9d9a285..7b78aea0f03b 100644 --- a/colossalai/inference/hybridengine/modeling/chatglm2.py +++ b/colossalai/inference/hybridengine/modeling/chatglm2.py @@ -3,7 +3,7 @@ import torch from transformers.utils import logging -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.kvcache_manager import BatchInferState from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig diff --git a/colossalai/inference/hybridengine/polices/__init__.py b/colossalai/inference/hybridengine/polices/__init__.py index 84dfb5aff773..eb7da8bd50fb 100644 --- a/colossalai/inference/hybridengine/polices/__init__.py +++ b/colossalai/inference/hybridengine/polices/__init__.py @@ -1,5 +1,5 @@ from .bloom import BloomModelInferPolicy -from .chatglm import ChatGLM2InferPolicy +from .chatglm2 import ChatGLM2InferPolicy from .llama import LlamaModelInferPolicy __all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"] diff --git a/colossalai/inference/hybridengine/polices/chatglm.py b/colossalai/inference/hybridengine/polices/chatglm2.py similarity index 100% rename from colossalai/inference/hybridengine/polices/chatglm.py rename to colossalai/inference/hybridengine/polices/chatglm2.py