Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions colossalai/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ This project was written from scratch but we learned a lot from several other gr
- [vLLM](https://github.com/vllm-project/vllm)
- [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [HuggingFace](https://huggingface.co)
- [StreamingLLM](https://github.com/mit-han-lab/streaming-llm)
If you wish to cite relevant research papars, you can find the reference below.

```bibtex
Expand All @@ -301,4 +302,12 @@ If you wish to cite relevant research papars, you can find the reference below.
author={Dao, Tri},
year={2023}
}

# StreamingLLM
@article{xiao2023streamingllm,
title={Efficient Streaming Language Models with Attention Sinks},
author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
journal={arXiv},
year={2023}
}
```
39 changes: 38 additions & 1 deletion colossalai/inference/batch_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(
fd_interm_tensor=None,
device=None,
dtype=torch.float16,
enable_streamingllm: bool = False,
start_token_size: int = 4,
generated_token_size: int = 512,
):
self.num_heads = num_heads
self.head_dim = head_dim
Expand All @@ -45,12 +48,19 @@ def __init__(
self._use_spec_dec = False
self._num_tokens_to_verify = None

self.enable_streamingllm = enable_streamingllm
self.start_token_size = start_token_size
self.generated_token_size = generated_token_size

self._current_batch_size = 0
self._sequences_dict = dict()
self._sequences_indexes = dict() # deque(maxlen=self.max_batch_size)
self._sequence_lengths = torch.zeros((self.max_batch_size,), dtype=torch.int32)
self._sequence_lengths_helper = torch.zeros_like(self._sequence_lengths)
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
if enable_streamingllm:
max_blocks_per_seq = (start_token_size + generated_token_size + block_size - 1) // block_size + 1
else:
max_blocks_per_seq = (self.max_length + block_size - 1) // block_size
self._block_tables = torch.full((self.max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
self._block_tables_helper = torch.full_like(self._block_tables, -1)

Expand Down Expand Up @@ -109,6 +119,33 @@ def batch_token_ids(self) -> List[List[int]]:
out.append(seq.input_token_id + seq.output_token_id)
return out

def streamingllm_update_batch(self, start_token_size: int, generated_token_size: int):
"""
Update sequence_lengths and block_tables when it is necessary to swap out a block.
"""

updated_block_ids = []

if self.current_batch_size > 0:
need_update = False
sequence_lengths_list = self._sequence_lengths.tolist()
block_tables_list = self._block_tables[: self._current_batch_size].tolist()
for batch_id in range(self.current_batch_size):
# We assume that the start token occupies the entire first block.
if sequence_lengths_list[batch_id] == start_token_size + generated_token_size + self.block_size - 1:
need_update = True
sequence_lengths_list[batch_id] = start_token_size + generated_token_size - 1
block_id = block_tables_list[batch_id].pop(1)
updated_block_ids.append(block_id)
block_tables_list[batch_id].append(-1)
if need_update:
self._sequence_lengths = torch.tensor(
sequence_lengths_list, dtype=self._sequence_lengths.dtype, device=self.device
)
self._block_tables = torch.tensor(block_tables_list, dtype=self._block_tables.dtype, device=self.device)
Comment thread
isky-cd marked this conversation as resolved.

return updated_block_ids

def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
Expand Down
32 changes: 27 additions & 5 deletions colossalai/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ class InferenceConfig(RPC_PARAM):
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
temperature (Optional[float]): Randomness used to control randomization, defaults to 1.0.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
Expand All @@ -176,10 +177,12 @@ class InferenceConfig(RPC_PARAM):
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
enable_streamingllm(bool): Whether to use StreamingLLM, the relevant algorithms refer to the paper at https://arxiv.org/pdf/2309.17453 for implementation.
start_token_size(int): The size of the start tokens, when using StreamingLLM.
generated_token_size(int): The size of the generated tokens, When using StreamingLLM.
"""

# NOTE: arrange configs according to their importance and frequency of usage
Expand Down Expand Up @@ -208,6 +211,7 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
forced_eos_token_id: int = None
ignore_eos: bool = False

# speculative decoding configs
max_n_spec_tokens: int = 5
Expand All @@ -221,15 +225,19 @@ class InferenceConfig(RPC_PARAM):
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False

# cuda kernel option
use_cuda_kernel: bool = False
high_precision: Optional[bool] = False

# cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
max_context_len_to_capture: int = 512
ignore_eos: bool = False

# StreamingLLM (sliding window attention with attention sinks)
enable_streamingllm: bool = False
start_token_size: int = 4
generated_token_size: int = 512

def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
Expand Down Expand Up @@ -260,6 +268,20 @@ def _verify_config(self) -> None:
if self.dtype == torch.float32:
self.high_precision = False

# check StreamingLLM
assert (
self.start_token_size <= self.block_size
), f"According to the paper https://arxiv.org/pdf/2309.17453, the start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less or equal than the block_size={self.block_size}, but got {self.start_token_size}."
assert (
self.generated_token_size % self.block_size == 0
), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
# Our StreamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized
# based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore,
# we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size,
# we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens."
# Thereafter, we swap out tokens in units of blocks, and always swapping out the second block when the generated tokens exceeded the limit.
self.start_token_size = self.block_size
Comment thread
isky-cd marked this conversation as resolved.

# check prompt template
if self.prompt_template is None:
return
Expand Down
12 changes: 12 additions & 0 deletions colossalai/inference/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,11 @@ def add_request(
elif max_length is not None:
max_new_tokens = max_length - len(prompts_token_ids[i])

if not self.inference_config.enable_streamingllm:
assert (
self.inference_config.max_output_len >= max_new_tokens
), f"max_new_tokens={max_new_tokens} must be less than max_output_len={self.inference_config.max_output_len}."

sequence = Sequence(
request_id,
prompt,
Expand Down Expand Up @@ -754,6 +759,13 @@ def step(self) -> List[str]:
logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
if self.inference_config.pad_input:
logits = logits[:, -1, :]

if self.inference_config.enable_streamingllm:
updated_block_ids = batch.streamingllm_update_batch(
self.inference_config.start_token_size, self.inference_config.generated_token_size
)
self.request_handler.streamingllm_free_block_tables(updated_block_ids)

next_tokens = search_tokens(
self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
)
Expand Down
12 changes: 12 additions & 0 deletions colossalai/inference/core/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
device=device,
enable_streamingllm=inference_config.enable_streamingllm,
start_token_size=inference_config.start_token_size,
generated_token_size=inference_config.generated_token_size,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads // inference_config.tp_size,
Expand All @@ -168,6 +171,9 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
fd_interm_tensor=fd_inter_tensor,
dtype=self.dtype,
device=device,
enable_streamingllm=inference_config.enable_streamingllm,
start_token_size=inference_config.start_token_size,
generated_token_size=inference_config.generated_token_size,
)

def _init_cache(self, model_config):
Expand Down Expand Up @@ -350,6 +356,12 @@ def update(self):

return finished_seqs

def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
"""
Free the block that needs to be swapped out.
"""
self.cache_manager.streamingllm_free_block_tables(updated_block_ids)


class RPCRequestHandler(RequestHandler):
"""
Expand Down
38 changes: 32 additions & 6 deletions colossalai/inference/kv_cache/kvcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,16 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> N
self.max_output_length = config.max_output_len
# Cache block settings
self.block_size = config.block_size

# NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
if config.enable_streamingllm:
self.max_blocks_per_sequence = (
config.start_token_size + config.generated_token_size + self.block_size - 1
) // self.block_size + 1
else:
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

# Physical cache allocation
Expand Down Expand Up @@ -446,6 +452,20 @@ def clear_all(self) -> None:
self._available_blocks = self.num_blocks
self._block_states[:] = 1

def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
"""
Free the block that needs to be swapped out.
"""
for global_block_id in updated_block_ids:
if global_block_id < 0:
return
block: CacheBlock = self._cache_blocks[global_block_id]
block.remove_ref()
if not block.has_ref():
block.allocated_size = 0
self._available_blocks += 1
self._block_states[global_block_id] = 1

def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
Expand Down Expand Up @@ -533,10 +553,16 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.max_output_length = config.max_output_len
# Cache block settings
self.block_size = config.block_size

# NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
if config.enable_streamingllm:
self.max_blocks_per_sequence = (
config.start_token_size + config.generated_token_size + self.block_size - 1
) // self.block_size + 1
else:
self.max_blocks_per_sequence = (
self.max_input_length + self.max_output_length + self.block_size - 1
) // self.block_size
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

# Logical cache blocks allocation
Expand Down
24 changes: 24 additions & 0 deletions examples/inference/llama/llama_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def infer(args):
block_size=16,
tp_size=args.tp_size,
use_cuda_kernel=args.use_cuda_kernel,
enable_streamingllm=args.enable_streamingllm,
start_token_size=args.start_token_size,
generated_token_size=args.generated_token_size,
)
coordinator.print_on_master(f"Initializing Inference Engine...")
engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
Expand All @@ -63,6 +66,8 @@ def infer(args):
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
no_repeat_ngram_size=args.no_repeat_ngram_size,
repetition_penalty=args.repetition_penalty,
)
coordinator.print_on_master(f"Generating...")
out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
Expand Down Expand Up @@ -107,6 +112,25 @@ def infer(args):
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
parser.add_argument("--enable_streamingllm", action="store_true", help="Whether to use StreamingLLM")
parser.add_argument(
"--start_token_size", type=int, default=4, help="The size of the start_token, When using StreamingLLM,"
)
parser.add_argument(
"--generated_token_size", type=int, default=512, help="The size of the generated_token, When using StreamingLLM"
)
parser.add_argument(
"--no_repeat_ngram_size",
type=int,
default=0,
help="If no_repeat_ngram_size > 0, the consecutive tokens of ngram size can only appear once in inference sentences.",
)
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.0,
help="The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition., defaults to 1.0.",
)
args = parser.parse_args()

infer(args)
Loading