From 829ccc28ce7680543075c59891c73500c0128960 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 22 May 2024 09:17:20 +0000 Subject: [PATCH 1/7] Add Streaming LLM --- colossalai/inference/batch_bucket.py | 39 ++++++++++++++++++- colossalai/inference/config.py | 23 ++++++++--- colossalai/inference/core/engine.py | 11 ++++++ colossalai/inference/core/request_handler.py | 12 ++++++ .../inference/kv_cache/kvcache_manager.py | 38 +++++++++++++++--- examples/inference/llama/llama_generation.py | 15 +++++-- 6 files changed, 123 insertions(+), 15 deletions(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index f8571c0ca030..8b75ef6ae5d6 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -31,6 +31,9 @@ def __init__( fd_interm_tensor=None, device=None, dtype=torch.float16, + enable_streamingllm: bool = False, + start_token_size: int = 4, + generate_token_size: int = 512, ): self.num_heads = num_heads self.head_dim = head_dim @@ -45,12 +48,19 @@ def __init__( self._use_spec_dec = False self._num_tokens_to_verify = None + self.enable_streamingllm = enable_streamingllm + self.start_token_size = start_token_size + self.generate_token_size = generate_token_size + self._current_batch_size = 0 self._sequences_dict = dict() self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size) self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32) self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths) - max_blocks_per_seq = (self.max_length + block_size - 1) // block_size + if enable_streamingllm: + max_blocks_per_seq = (start_token_size + generate_token_size + block_size - 1) // block_size + 1 + else: + max_blocks_per_seq = (self.max_length + block_size - 1) // block_size self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32) self._block_tables_helper = torch.full_like(self._block_tables, -1) @@ -109,6 +119,33 @@ def batch_token_ids(self) -> List[List[int]]: out.append(seq.input_token_id + seq.output_token_id) return out + def streamingllm_update_batch(self, start_token_size: int, generate_token_size: int): + """ + Update sequence_lengths and block_tables when it is necessary to swap out a block. + """ + + updated_block_ids = [] + + if self.current_batch_size > 0: + need_update = False + sequence_lengths_list = self._sequence_lengths.tolist() + block_tables_list = self._block_tables.tolist() + for batch_id in range(self.current_batch_size): + # We assume that the start token occupies the entire first block. + if sequence_lengths_list[batch_id] == start_token_size + generate_token_size + self.block_size - 1: + need_update = True + sequence_lengths_list[batch_id] = start_token_size + generate_token_size + block_id = block_tables_list[batch_id].pop(1) + updated_block_ids.append(block_id) + block_tables_list[batch_id].append(-1) + if need_update: + self._sequence_lengths = torch.tensor( + sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device + ) + self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device) + + return updated_block_ids + def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None: """Set batch bucket to use speculatvie decoding. This will notify the adjust the lengths of inputs during modeling, diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 61bc7c8abc9c..06613c07c1cb 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -166,8 +166,9 @@ class InferenceConfig(RPC_PARAM): top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None. top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None. temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0. - repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences. + repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0. + ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. n_spec_tokens (int): The maximum number of speculating tokens, defaults to None. glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False. block_size (int): The number of blocks in a logical block, defaults to 16. @@ -176,10 +177,12 @@ class InferenceConfig(RPC_PARAM): micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid. max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence - high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. - ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token. + enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation. + start_token_size(int): The size of the start token, When using StreamingLLM, + generate_token_size(int): """ # NOTE: arrange configs according to their importance and frequency of usage @@ -208,6 +211,7 @@ class InferenceConfig(RPC_PARAM): no_repeat_ngram_size: Optional[int] = 0 repetition_penalty: Optional[float] = 1.0 forced_eos_token_id: int = None + ignore_eos: bool = False # speculative decoding configs max_n_spec_tokens: int = 5 @@ -221,18 +225,27 @@ class InferenceConfig(RPC_PARAM): pp_size: int = 1 micro_batch_size: int = 1 micro_batch_buffer_size: int = None - high_precision: Optional[bool] = False # cuda kernel option use_cuda_kernel: bool = False + high_precision: Optional[bool] = False # cuda_graph use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference max_context_len_to_capture: int = 512 - ignore_eos: bool = False + + # StreamingLLM + enable_streamingllm: bool = False + start_token_size: int = 4 + generate_token_size: int = 512 def __post_init__(self): self.max_context_len_to_capture = self.max_input_len + self.max_output_len + assert ( + self.start_token_size <= self.block_size + ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." + # We assume that start_token_size occupies one block. + self.start_token_size = self.block_size self._verify_config() def _verify_config(self) -> None: diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 96c2b15ee16e..e8f02e332c4b 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -667,6 +667,10 @@ def add_request( elif max_length is not None: max_new_tokens = max_length - len(prompts_token_ids[i]) + assert ( + self.inference_config.max_output_len >= max_new_tokens + ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." + sequence = Sequence( request_id, prompt, @@ -754,6 +758,13 @@ def step(self) -> List[str]: logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) if self.inference_config.pad_input: logits = logits[:, -1, :] + + if self.inference_config.enable_streamingllm: + updated_block_ids = batch.streamingllm_update_batch( + self.inference_config.start_token_size, self.inference_config.generate_token_size + ) + self.request_handler.streamingllm_free_block_tables(updated_block_ids) + next_tokens = search_tokens( self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids ) diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 5085c55558b4..823a1ca1bbdf 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -157,6 +157,9 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo fd_interm_tensor=fd_inter_tensor, dtype=self.dtype, device=device, + enable_streamingllm=inference_config.enable_streamingllm, + start_token_size=inference_config.start_token_size, + generate_token_size=inference_config.generate_token_size, ) self.prefill_bb = BatchBucket( num_heads=model_config.num_attention_heads // inference_config.tp_size, @@ -168,6 +171,9 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo fd_interm_tensor=fd_inter_tensor, dtype=self.dtype, device=device, + enable_streamingllm=inference_config.enable_streamingllm, + start_token_size=inference_config.start_token_size, + generate_token_size=inference_config.generate_token_size, ) def _init_cache(self, model_config): @@ -350,6 +356,12 @@ def update(self): return finished_seqs + def streamingllm_free_block_tables(self, updated_block_ids: List[int]): + """ + Free the block that needs to be swapped out. + """ + self.cache_manager.streamingllm_free_block_tables(updated_block_ids) + class RPCRequestHandler(RequestHandler): """ diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index a20bd8ee79ea..fd7325d4c8b7 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -78,10 +78,16 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> N self.max_output_length = config.max_output_len # Cache block settings self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size - self.max_blocks_per_sequence = ( - self.max_input_length + self.max_output_length + self.block_size - 1 - ) // self.block_size + if config.enable_streamingllm: + self.max_blocks_per_sequence = ( + config.start_token_size + config.generate_token_size + self.block_size - 1 + ) // self.block_size + 1 + else: + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + self.block_size - 1 + ) // self.block_size self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Physical cache allocation @@ -446,6 +452,20 @@ def clear_all(self) -> None: self._available_blocks = self.num_blocks self._block_states[:] = 1 + def streamingllm_free_block_tables(self, updated_block_ids: List[int]): + """ + Free the block that needs to be swapped out. + """ + for global_block_id in updated_block_ids: + if global_block_id < 0: + return + block: CacheBlock = self._cache_blocks[global_block_id] + block.remove_ref() + if not block.has_ref(): + block.allocated_size = 0 + self._available_blocks += 1 + self._block_states[global_block_id] = 1 + def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """Get the tensor corresponding to the cache block with the prompted id for a specific layer.""" return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx] @@ -533,10 +553,16 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb self.max_output_length = config.max_output_len # Cache block settings self.block_size = config.block_size + # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size - self.max_blocks_per_sequence = ( - self.max_input_length + self.max_output_length + self.block_size - 1 - ) // self.block_size + if config.enable_streamingllm: + self.max_blocks_per_sequence = ( + config.start_token_size + config.generate_token_size + self.block_size - 1 + ) // self.block_size + 1 + else: + self.max_blocks_per_sequence = ( + self.max_input_length + self.max_output_length + self.block_size - 1 + ) // self.block_size self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width # Logical cache blocks allocation diff --git a/examples/inference/llama/llama_generation.py b/examples/inference/llama/llama_generation.py index c0a1a585a1b9..5d7f63353623 100644 --- a/examples/inference/llama/llama_generation.py +++ b/examples/inference/llama/llama_generation.py @@ -29,6 +29,11 @@ def infer(args): tokenizer.pad_token = tokenizer.eos_token # coordinator.print_on_master(f"Model Config:\n{model.config}") + prompts = [ + "介绍一下北京,", + "介绍一下武汉,", + ] + # ============================== # Initialize InferenceEngine # ============================== @@ -41,9 +46,11 @@ def infer(args): block_size=16, tp_size=args.tp_size, use_cuda_kernel=args.use_cuda_kernel, + enable_streamingllm=False, + generate_token_size=64, ) coordinator.print_on_master(f"Initializing Inference Engine...") - engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True) + engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) # ============================== # Generation @@ -56,9 +63,11 @@ def infer(args): temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, + no_repeat_ngram_size=1, + repetition_penalty=1.5, ) coordinator.print_on_master(f"Generating...") - out = engine.generate(prompts=[args.prompt], generation_config=generation_config) + out = engine.generate(prompts=prompts, generation_config=generation_config) coordinator.print_on_master(out) # ============================== @@ -70,7 +79,7 @@ def infer(args): # turn on speculative decoding with the drafter model engine.enable_spec_dec(drafter_model) coordinator.print_on_master(f"Generating...") - out = engine.generate(prompts=[args.prompt], generation_config=generation_config) + out = engine.generate(prompts=prompts, generation_config=generation_config) coordinator.print_on_master(out) engine.disable_spec_dec() From c3f4edb3bd4d5140bc664cba1668eb87d8aa8ec9 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 23 May 2024 02:53:11 +0000 Subject: [PATCH 2/7] add some parameters to llama_generation.py --- colossalai/inference/batch_bucket.py | 12 +++--- colossalai/inference/config.py | 6 +-- colossalai/inference/core/engine.py | 2 +- colossalai/inference/core/request_handler.py | 4 +- .../inference/kv_cache/kvcache_manager.py | 4 +- examples/inference/llama/llama_generation.py | 39 +++++++++++++------ 6 files changed, 41 insertions(+), 26 deletions(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 8b75ef6ae5d6..1214e82a5aac 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -33,7 +33,7 @@ def __init__( dtype=torch.float16, enable_streamingllm: bool = False, start_token_size: int = 4, - generate_token_size: int = 512, + generated_token_size: int = 512, ): self.num_heads = num_heads self.head_dim = head_dim @@ -50,7 +50,7 @@ def __init__( self.enable_streamingllm = enable_streamingllm self.start_token_size = start_token_size - self.generate_token_size = generate_token_size + self.generated_token_size = generated_token_size self._current_batch_size = 0 self._sequences_dict = dict() @@ -58,7 +58,7 @@ def __init__( self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32) self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths) if enable_streamingllm: - max_blocks_per_seq = (start_token_size + generate_token_size + block_size - 1) // block_size + 1 + max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1 else: max_blocks_per_seq = (self.max_length + block_size - 1) // block_size self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32) @@ -119,7 +119,7 @@ def batch_token_ids(self) -> List[List[int]]: out.append(seq.input_token_id + seq.output_token_id) return out - def streamingllm_update_batch(self, start_token_size: int, generate_token_size: int): + def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int): """ Update sequence_lengths and block_tables when it is necessary to swap out a block. """ @@ -132,9 +132,9 @@ def streamingllm_update_batch(self, start_token_size: int, generate_token_size: block_tables_list = self._block_tables.tolist() for batch_id in range(self.current_batch_size): # We assume that the start token occupies the entire first block. - if sequence_lengths_list[batch_id] == start_token_size + generate_token_size + self.block_size - 1: + if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1: need_update = True - sequence_lengths_list[batch_id] = start_token_size + generate_token_size + sequence_lengths_list[batch_id] = start_token_size + generated_token_size block_id = block_tables_list[batch_id].pop(1) updated_block_ids.append(block_id) block_tables_list[batch_id].append(-1) diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 06613c07c1cb..9fa630d2c4da 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -181,8 +181,8 @@ class InferenceConfig(RPC_PARAM): use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid. max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation. - start_token_size(int): The size of the start token, When using StreamingLLM, - generate_token_size(int): + start_token_size(int): The size of the start_token, When using StreamingLLM. + generated_token_size(int): The size of the generated_token, When using StreamingLLM. """ # NOTE: arrange configs according to their importance and frequency of usage @@ -237,7 +237,7 @@ class InferenceConfig(RPC_PARAM): # StreamingLLM enable_streamingllm: bool = False start_token_size: int = 4 - generate_token_size: int = 512 + generated_token_size: int = 512 def __post_init__(self): self.max_context_len_to_capture = self.max_input_len + self.max_output_len diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index e8f02e332c4b..8a65bfc3c481 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -761,7 +761,7 @@ def step(self) -> List[str]: if self.inference_config.enable_streamingllm: updated_block_ids = batch.streamingllm_update_batch( - self.inference_config.start_token_size, self.inference_config.generate_token_size + self.inference_config.start_token_size, self.inference_config.generated_token_size ) self.request_handler.streamingllm_free_block_tables(updated_block_ids) diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 823a1ca1bbdf..512eaea71c7b 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -159,7 +159,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo device=device, enable_streamingllm=inference_config.enable_streamingllm, start_token_size=inference_config.start_token_size, - generate_token_size=inference_config.generate_token_size, + generated_token_size=inference_config.generated_token_size, ) self.prefill_bb = BatchBucket( num_heads=model_config.num_attention_heads // inference_config.tp_size, @@ -173,7 +173,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo device=device, enable_streamingllm=inference_config.enable_streamingllm, start_token_size=inference_config.start_token_size, - generate_token_size=inference_config.generate_token_size, + generated_token_size=inference_config.generated_token_size, ) def _init_cache(self, model_config): diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index fd7325d4c8b7..378eb2ff9151 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -82,7 +82,7 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> N # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size if config.enable_streamingllm: self.max_blocks_per_sequence = ( - config.start_token_size + config.generate_token_size + self.block_size - 1 + config.start_token_size + config.generated_token_size + self.block_size - 1 ) // self.block_size + 1 else: self.max_blocks_per_sequence = ( @@ -557,7 +557,7 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size if config.enable_streamingllm: self.max_blocks_per_sequence = ( - config.start_token_size + config.generate_token_size + self.block_size - 1 + config.start_token_size + config.generated_token_size + self.block_size - 1 ) // self.block_size + 1 else: self.max_blocks_per_sequence = ( diff --git a/examples/inference/llama/llama_generation.py b/examples/inference/llama/llama_generation.py index 5d7f63353623..73db7dff623e 100644 --- a/examples/inference/llama/llama_generation.py +++ b/examples/inference/llama/llama_generation.py @@ -29,11 +29,6 @@ def infer(args): tokenizer.pad_token = tokenizer.eos_token # coordinator.print_on_master(f"Model Config:\n{model.config}") - prompts = [ - "介绍一下北京,", - "介绍一下武汉,", - ] - # ============================== # Initialize InferenceEngine # ============================== @@ -46,11 +41,12 @@ def infer(args): block_size=16, tp_size=args.tp_size, use_cuda_kernel=args.use_cuda_kernel, - enable_streamingllm=False, - generate_token_size=64, + enable_streamingllm=args.enable_streamingllm, + start_token_size=args.start_token_size, + generated_token_size=args.generated_token_size, ) coordinator.print_on_master(f"Initializing Inference Engine...") - engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True) # ============================== # Generation @@ -63,11 +59,11 @@ def infer(args): temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, - no_repeat_ngram_size=1, - repetition_penalty=1.5, + no_repeat_ngram_size=args.no_repeat_ngram_size, + repetition_penalty=args.repetition_penalty, ) coordinator.print_on_master(f"Generating...") - out = engine.generate(prompts=prompts, generation_config=generation_config) + out = engine.generate(prompts=[args.prompt], generation_config=generation_config) coordinator.print_on_master(out) # ============================== @@ -79,7 +75,7 @@ def infer(args): # turn on speculative decoding with the drafter model engine.enable_spec_dec(drafter_model) coordinator.print_on_master(f"Generating...") - out = engine.generate(prompts=prompts, generation_config=generation_config) + out = engine.generate(prompts=[args.prompt], generation_config=generation_config) coordinator.print_on_master(out) engine.disable_spec_dec() @@ -109,6 +105,25 @@ def infer(args): parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation") parser.add_argument("--top_k", type=int, default=50, help="Top k for generation") parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation") + parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM") + parser.add_argument( + "--start_token_size", type=int, default=4, help="The size of the start_token, When using StreamingLLM," + ) + parser.add_argument( + "--generated_token_size", type=int, default=512, help="The size of the generated_token, When using StreamingLLM" + ) + parser.add_argument( + "--no_repeat_ngram_size", + type=int, + default=0, + help="If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.", + ) + parser.add_argument( + "--repetition_penalty", + type=float, + default=1.0, + help="The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.", + ) args = parser.parse_args() infer(args) From d63e068d4abc98d0756f680ba1a158c6e4aea2d0 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Thu, 23 May 2024 03:46:00 +0000 Subject: [PATCH 3/7] verify streamingllm config --- colossalai/inference/config.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 9fa630d2c4da..1515ca6d9eec 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -241,11 +241,6 @@ class InferenceConfig(RPC_PARAM): def __post_init__(self): self.max_context_len_to_capture = self.max_input_len + self.max_output_len - assert ( - self.start_token_size <= self.block_size - ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." - # We assume that start_token_size occupies one block. - self.start_token_size = self.block_size self._verify_config() def _verify_config(self) -> None: @@ -285,6 +280,15 @@ def _verify_config(self) -> None: "{input_text}" in self.prompt_template ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '" + assert ( + self.start_token_size <= self.block_size + ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." + assert ( + self.generated_token_size % self.block_size == 0 + ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}." + # We assume that start_token_size occupies one block. + self.start_token_size = self.block_size + def to_generation_config(self, model_config) -> GenerationConfig: meta_config = { "max_length": self.max_input_len + self.max_output_len, From 21a573aded392d5dc24c4ae8507ef1d57040eebf Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Mon, 3 Jun 2024 08:32:32 +0000 Subject: [PATCH 4/7] add test_streamingllm.py --- colossalai/inference/batch_bucket.py | 2 +- colossalai/inference/config.py | 19 ++-- tests/test_infer/test_streamingllm.py | 129 ++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 10 deletions(-) create mode 100644 tests/test_infer/test_streamingllm.py diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 1214e82a5aac..65c7b84309b4 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -134,7 +134,7 @@ def streamingllm_update_batch(self, start_token_size: int, generated_token_size: # We assume that the start token occupies the entire first block. if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1: need_update = True - sequence_lengths_list[batch_id] = start_token_size + generated_token_size + sequence_lengths_list[batch_id] = start_token_size + generated_token_size - 1 block_id = block_tables_list[batch_id].pop(1) updated_block_ids.append(block_id) block_tables_list[batch_id].append(-1) diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 1515ca6d9eec..b76f93ba5573 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -268,6 +268,16 @@ def _verify_config(self) -> None: if self.dtype == torch.float32: self.high_precision = False + # check streamingLLM + assert ( + self.start_token_size <= self.block_size + ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." + assert ( + self.generated_token_size % self.block_size == 0 + ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}." + # We assume that start_token_size occupies one block. + self.start_token_size = self.block_size + # check prompt template if self.prompt_template is None: return @@ -280,15 +290,6 @@ def _verify_config(self) -> None: "{input_text}" in self.prompt_template ), "The prompt template should contain '{input_text}' for formatting the input text. For example: 'USER: {input_text}\n\nASSISTANT: '" - assert ( - self.start_token_size <= self.block_size - ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." - assert ( - self.generated_token_size % self.block_size == 0 - ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}." - # We assume that start_token_size occupies one block. - self.start_token_size = self.block_size - def to_generation_config(self, model_config) -> GenerationConfig: meta_config = { "max_length": self.max_input_len + self.max_output_len, diff --git a/tests/test_infer/test_streamingllm.py b/tests/test_infer/test_streamingllm.py new file mode 100644 index 000000000000..f90e9af46c5d --- /dev/null +++ b/tests/test_infer/test_streamingllm.py @@ -0,0 +1,129 @@ +import random + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch.multiprocessing import Manager +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM + +import colossalai +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.engine import InferenceEngine +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def data_gen(batch_size: int = 4, seq_len: int = 512): + input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=torch.cuda.current_device()) + return input_ids + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def check_streamingllm(): + setup_seed(20) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + model = LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, + hidden_size=512, + intermediate_size=1536, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=16, + ) + ).cuda() + model = model.eval() + + input_token_ids = data_gen(1, 4) + + output_len = 128 + + inference_config = InferenceConfig( + max_batch_size=1, + max_output_len=output_len, + dtype="fp32", + use_cuda_kernel=True, + tp_size=dist.get_world_size(), + enable_streamingllm=True, + start_token_size=4, + generated_token_size=32, + ) + + print("inference_config.start_token_size: ", inference_config.start_token_size) + + # assert inference_config.start_token_size == inference_config.block_size + inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + assert inference_engine.generation_config.max_new_tokens == output_len + inference_engine.add_request(prompts_token_ids=input_token_ids) + assert inference_engine.request_handler._has_waiting() + + assert inference_config.start_token_size == inference_config.block_size + + request_handler = inference_engine.request_handler + running_bb = request_handler.running_bb + + for _ in range(12): + inference_engine.step() + + assert running_bb.block_tables[0].tolist() == [0, -1, -1, -1] + assert running_bb.seq_lengths[0].item() == 16 + + for _ in range(16): + inference_engine.step() + + assert running_bb.block_tables[0].tolist() == [0, 1, -1, -1] + assert running_bb.seq_lengths[0].item() == 32 + + for _ in range(16): + inference_engine.step() + + assert running_bb.block_tables[0].tolist() == [0, 1, 2, -1] + assert running_bb.seq_lengths[0].item() == 48 + + for _ in range(16): + inference_engine.step() + + assert running_bb.block_tables[0].tolist() == [0, 2, 3, -1] + assert running_bb.seq_lengths[0].item() == 48 + + for _ in range(1): + inference_engine.step() + + assert running_bb.block_tables[0].tolist() == [0, 2, 3, 1] + assert running_bb.seq_lengths[0].item() == 49 + + for _ in range(15): + inference_engine.step() + + assert running_bb.block_tables[0].tolist() == [0, 3, 1, -1] + assert running_bb.seq_lengths[0].item() == 48 + + +def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") + + if ret: + ret[rank] = func_to_run(**kwargs) + else: + func_to_run(**kwargs) + + +@pytest.mark.largedist +@rerun_if_address_is_in_use() +def test_tp_engine(): + manager = Manager() + result_list = manager.list([-1] * 1) # Create a shared list + + spawn(run_dist, 1, func_to_run=check_streamingllm, ret=result_list) + return result_list[0] + + +if __name__ == "__main__": + test_tp_engine() From fb8116187019b91617436c19377d07205004ceff Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Tue, 4 Jun 2024 03:40:01 +0000 Subject: [PATCH 5/7] modified according to the opinions of review --- colossalai/inference/batch_bucket.py | 2 +- colossalai/inference/config.py | 12 ++++++++---- tests/test_infer/test_streamingllm.py | 11 ++--------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 65c7b84309b4..0406be15eac6 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -129,7 +129,7 @@ def streamingllm_update_batch(self, start_token_size: int, generated_token_size: if self.current_batch_size > 0: need_update = False sequence_lengths_list = self._sequence_lengths.tolist() - block_tables_list = self._block_tables.tolist() + block_tables_list = self._block_tables[: self._current_batch_size - 1].tolist() for batch_id in range(self.current_batch_size): # We assume that the start token occupies the entire first block. if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1: diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index b76f93ba5573..49340a8a58a4 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -181,8 +181,8 @@ class InferenceConfig(RPC_PARAM): use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid. max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation. - start_token_size(int): The size of the start_token, When using StreamingLLM. - generated_token_size(int): The size of the generated_token, When using StreamingLLM. + start_token_size(int): The size of the start tokens, when using StreamingLLM. + generated_token_size(int): The size of the generated tokens, When using StreamingLLM. """ # NOTE: arrange configs according to their importance and frequency of usage @@ -234,7 +234,7 @@ class InferenceConfig(RPC_PARAM): use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference max_context_len_to_capture: int = 512 - # StreamingLLM + # StreamingLLM (sliding window attention with attention sinks) enable_streamingllm: bool = False start_token_size: int = 4 generated_token_size: int = 512 @@ -275,7 +275,11 @@ def _verify_config(self) -> None: assert ( self.generated_token_size % self.block_size == 0 ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}." - # We assume that start_token_size occupies one block. + # Our streamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized + # based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore, + # we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size, + # we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens." + # Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit. self.start_token_size = self.block_size # check prompt template diff --git a/tests/test_infer/test_streamingllm.py b/tests/test_infer/test_streamingllm.py index f90e9af46c5d..f8b6487f1019 100644 --- a/tests/test_infer/test_streamingllm.py +++ b/tests/test_infer/test_streamingllm.py @@ -1,9 +1,7 @@ import random import numpy as np -import pytest import torch -import torch.distributed as dist from torch.multiprocessing import Manager from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM @@ -50,15 +48,11 @@ def check_streamingllm(): max_output_len=output_len, dtype="fp32", use_cuda_kernel=True, - tp_size=dist.get_world_size(), enable_streamingllm=True, start_token_size=4, generated_token_size=32, ) - print("inference_config.start_token_size: ", inference_config.start_token_size) - - # assert inference_config.start_token_size == inference_config.block_size inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts_token_ids=input_token_ids) @@ -115,9 +109,8 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs): func_to_run(**kwargs) -@pytest.mark.largedist @rerun_if_address_is_in_use() -def test_tp_engine(): +def test_engine(): manager = Manager() result_list = manager.list([-1] * 1) # Create a shared list @@ -126,4 +119,4 @@ def test_tp_engine(): if __name__ == "__main__": - test_tp_engine() + test_engine() From e94fc470fab142356980b0f99a0b4b0c6b7f70e7 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Tue, 4 Jun 2024 05:41:31 +0000 Subject: [PATCH 6/7] add Citation --- colossalai/inference/README.md | 9 +++++++++ colossalai/inference/config.py | 4 ++-- colossalai/inference/core/engine.py | 7 ++++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index b46222d806af..ec40441274cc 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -278,6 +278,7 @@ This project was written from scratch but we learned a lot from several other gr - [vLLM](https://github.com/vllm-project/vllm) - [flash-attention](https://github.com/Dao-AILab/flash-attention) - [HuggingFace](https://huggingface.co) +- [StreamingLLM](https://github.com/mit-han-lab/streaming-llm) If you wish to cite relevant research papars, you can find the reference below. ```bibtex @@ -301,4 +302,12 @@ If you wish to cite relevant research papars, you can find the reference below. author={Dao, Tri}, year={2023} } + +# StreamingLLM +@article{xiao2023streamingllm, + title={Efficient Streaming Language Models with Attention Sinks}, + author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike}, + journal={arXiv}, + year={2023} +} ``` diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 49340a8a58a4..9cf9a65e6007 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -268,14 +268,14 @@ def _verify_config(self) -> None: if self.dtype == torch.float32: self.high_precision = False - # check streamingLLM + # check StreamingLLM assert ( self.start_token_size <= self.block_size ), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}." assert ( self.generated_token_size % self.block_size == 0 ), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}." - # Our streamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized + # Our StreamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized # based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore, # we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size, # we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens." diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 8a65bfc3c481..1b6e62553bc6 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -667,9 +667,10 @@ def add_request( elif max_length is not None: max_new_tokens = max_length - len(prompts_token_ids[i]) - assert ( - self.inference_config.max_output_len >= max_new_tokens - ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." + if not self.inference_config.enable_streamingllm: + assert ( + self.inference_config.max_output_len >= max_new_tokens + ), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}." sequence = Sequence( request_id, From ccf737ffef3509ef2dfe12eabd7a0ba81a489763 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Tue, 4 Jun 2024 05:49:09 +0000 Subject: [PATCH 7/7] change _block_tables tolist --- colossalai/inference/batch_bucket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/inference/batch_bucket.py b/colossalai/inference/batch_bucket.py index 0406be15eac6..88bde3a3beeb 100644 --- a/colossalai/inference/batch_bucket.py +++ b/colossalai/inference/batch_bucket.py @@ -129,7 +129,7 @@ def streamingllm_update_batch(self, start_token_size: int, generated_token_size: if self.current_batch_size > 0: need_update = False sequence_lengths_list = self._sequence_lengths.tolist() - block_tables_list = self._block_tables[: self._current_batch_size - 1].tolist() + block_tables_list = self._block_tables[: self._current_batch_size].tolist() for batch_id in range(self.current_batch_size): # We assume that the start token occupies the entire first block. if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1: