From 70303d9f9222237c8e34231c77988a822f58046b Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 24 Oct 2023 13:59:43 +0800 Subject: [PATCH 01/16] adding flash-decoding --- .../tensor_parallel/modeling/llama.py | 37 ++++++++++++++----- index.html | 2 + 2 files changed, 29 insertions(+), 10 deletions(-) create mode 100644 index.html diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index a3937f6f10ba..f6d613afdbcf 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -35,6 +35,18 @@ print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") HAS_LIGHTLLM_KERNEL = False +try: + from xformers.ops import RMSNorm, fmha, rope_padded + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias, + ) + HAS_XFORMERS = True +except: + print("please install xformers from source to run inference:") + HAS_XFORMERS = False + + + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -355,16 +367,21 @@ def llama_flash_attn_kvcache_forward( attn_output = torch.empty_like(query_states) if self.num_key_value_groups == 1: - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - ) + # token_attention_fwd( + # query_states, + # infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + # infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + # attn_output, + # infer_state.block_loc, + # infer_state.start_loc, + # infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + # ) + attn_output = fmha.memory_efficient_attention_forward(query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attention_mask + ) else: Llama2TokenAttentionForwards.token_attn( query_states, diff --git a/index.html b/index.html new file mode 100644 index 000000000000..2cbc927f9a77 --- /dev/null +++ b/index.html @@ -0,0 +1,2 @@ + + 百度一下,你就知道

关于百度 About Baidu

©2017 Baidu 使用百度前必读  意见反馈 京ICP证030173号 

From 35465d8da1e352f86a2b4fd4644a11fc183be7cb Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 24 Oct 2023 13:59:56 +0800 Subject: [PATCH 02/16] clean --- index.html | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 index.html diff --git a/index.html b/index.html deleted file mode 100644 index 2cbc927f9a77..000000000000 --- a/index.html +++ /dev/null @@ -1,2 +0,0 @@ - - 百度一下,你就知道

关于百度 About Baidu

©2017 Baidu 使用百度前必读  意见反馈 京ICP证030173号 

From fe1986d4f3a3fde371dfc0296a255ab6e6242ac2 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 24 Oct 2023 17:13:00 +0800 Subject: [PATCH 03/16] adding kernel --- .../tensor_parallel/modeling/llama.py | 43 ++++++------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index f6d613afdbcf..ec9c38e14a24 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -221,7 +221,8 @@ def llama_model_forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - + + @staticmethod def llama_decoder_layer_forward( self: LlamaDecoderLayer, @@ -265,6 +266,7 @@ def llama_decoder_layer_forward( outputs += (present_key_value,) return outputs + @staticmethod def llama_flash_attn_kvcache_forward( @@ -366,34 +368,17 @@ def llama_flash_attn_kvcache_forward( # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_states) - if self.num_key_value_groups == 1: - # token_attention_fwd( - # query_states, - # infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - # infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - # attn_output, - # infer_state.block_loc, - # infer_state.start_loc, - # infer_state.seq_len, - # infer_state.cache_manager.past_key_values_length, - # ) - attn_output = fmha.memory_efficient_attention_forward(query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attention_mask - ) - else: - Llama2TokenAttentionForwards.token_attn( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - infer_state.other_kv_index, - ) + + heads_per_group = self.num_heads // self.num_key_value_heads + query_states = query_states.view(bsz, q_len, self.num_key_value_heads, heads_per_group, self.head_dim) + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) + cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) + + + attn_output = fmha.memory_efficient_attention_forward(query_states, cache_k, cache_v, None) attn_output = attn_output.view(bsz, q_len, self.hidden_size) From 7c91a9aeb416b6073c60a1b52776749a2cc6cd4e Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 24 Oct 2023 18:10:23 +0800 Subject: [PATCH 04/16] adding flash-decoding --- colossalai/inference/README.md | 15 +++++++++++++-- .../inference/tensor_parallel/modeling/llama.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index d0c281e057b3..f1a3104bc54e 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -34,11 +34,13 @@ In this section we discuss how the colossal inference works and integrates with - [x] policy - [x] context forward - [x] token forward + - [] support flash-decoding - [ ] Replace the kernels with `faster-transformer` in token-forward stage - [ ] Support all models - [x] Llama + - [x] Llama-2 - [x] Bloom - - [ ] Chatglm2 + - [x] Chatglm2 - [ ] Benchmarking for all models ## Get started @@ -68,6 +70,12 @@ git clone https://github.com/ModelTC/lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . + +# also, install xformers from source: +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers + ``` ### Docker @@ -89,7 +97,10 @@ git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . - +# install xformers from source +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers ``` ### Dive into fast-inference! diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ec9c38e14a24..e32e30da739d 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -42,7 +42,7 @@ ) HAS_XFORMERS = True except: - print("please install xformers from source to run inference:") + print("please install xformers from source to run inference: https://github.com/facebookresearch/xformers") HAS_XFORMERS = False From 35bb8b13d3673a2fb1210883eb770f944f02232e Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 24 Oct 2023 19:16:47 +0800 Subject: [PATCH 05/16] add integration --- .../tensor_parallel/modeling/llama.py | 52 +++++++++++++------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index e32e30da739d..f4c2251b0eae 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -362,23 +362,43 @@ def llama_flash_attn_kvcache_forward( infer_state.decode_mem_index, infer_state.cache_manager, ) - - # second token and follows - # kv = torch.stack((key_states, value_states), dim=2) - # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_states) - - - heads_per_group = self.num_heads // self.num_key_value_heads - query_states = query_states.view(bsz, q_len, self.num_key_value_heads, heads_per_group, self.head_dim) - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) - cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) - - - attn_output = fmha.memory_efficient_attention_forward(query_states, cache_k, cache_v, None) + if attention_mask is not None: + attn_output = torch.empty_like(query_states) + heads_per_group = self.num_heads // self.num_key_value_heads + query_states = query_states.view(bsz, q_len, self.num_key_value_heads, heads_per_group, self.head_dim) + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) + cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) + + + attn_output = fmha.memory_efficient_attention_forward(query_states, cache_k, cache_v, None) + + elif self.num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + infer_state.other_kv_index, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) From d01b9ac8bf997423f0fffb4646d7568502c6a7f9 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 24 Oct 2023 19:18:14 +0800 Subject: [PATCH 06/16] add --- colossalai/inference/tensor_parallel/modeling/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index f4c2251b0eae..6e59fdaa4331 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -363,7 +363,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager, ) - if attention_mask is not None: + if attention_mask is None: attn_output = torch.empty_like(query_states) heads_per_group = self.num_heads // self.num_key_value_heads query_states = query_states.view(bsz, q_len, self.num_key_value_heads, heads_per_group, self.head_dim) From 783b9e0068e2820b78905e40fde22f07413b841c Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 25 Oct 2023 14:26:38 +0800 Subject: [PATCH 07/16] adding kernel --- .../tensor_parallel/modeling/llama.py | 62 +++++++++---------- examples/inference/bench_llama.py | 4 +- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 6e59fdaa4331..213dab0c7bf5 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,4 +1,5 @@ from typing import List, Optional, Tuple +import math import torch from transformers.modeling_outputs import BaseModelOutputWithPast @@ -36,16 +37,11 @@ HAS_LIGHTLLM_KERNEL = False try: - from xformers.ops import RMSNorm, fmha, rope_padded - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias, - ) - HAS_XFORMERS = True + from flash_attn import flash_attn_with_kvcache + HAS_FLASH_KERNEL = True except: - print("please install xformers from source to run inference: https://github.com/facebookresearch/xformers") - HAS_XFORMERS = False - - + HAS_FLASH_KERNEL = False + print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") def rotate_half(x): @@ -362,32 +358,34 @@ def llama_flash_attn_kvcache_forward( infer_state.decode_mem_index, infer_state.cache_manager, ) - - if attention_mask is None: - attn_output = torch.empty_like(query_states) - heads_per_group = self.num_heads // self.num_key_value_heads - query_states = query_states.view(bsz, q_len, self.num_key_value_heads, heads_per_group, self.head_dim) - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) - cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, 1, self.head_dim) - - - attn_output = fmha.memory_efficient_attention_forward(query_states, cache_k, cache_v, None) + if self.num_key_value_groups == 1: + if HAS_FLASH_KERNEL: + attn_output = torch.empty_like(query_states) + heads_per_group = self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache(q = query_states, k_cache = cache_k, v_cache = cache_v, softmax_scale = 1/ math.sqrt(self.head_dim), causal = True) - elif self.num_key_value_groups == 1: - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - ) + else: + attn_output = torch.empty_like(query_states) + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) else: + attn_output = torch.empty_like(query_states) Llama2TokenAttentionForwards.token_attn( query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 0ca1953c6a41..077dd1634fca 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -121,8 +121,8 @@ def test_llama(args): parser = argparse.ArgumentParser() parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=128, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") parser.add_argument( "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] From 1d9596b237f7bfcdc999403b257badd14ccb3c43 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 25 Oct 2023 15:05:49 +0800 Subject: [PATCH 08/16] adding kernel --- colossalai/inference/tensor_parallel/modeling/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 213dab0c7bf5..60530ab6f022 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -360,8 +360,7 @@ def llama_flash_attn_kvcache_forward( ) if self.num_key_value_groups == 1: - if HAS_FLASH_KERNEL: - attn_output = torch.empty_like(query_states) + if HAS_FLASH_KERNEL: heads_per_group = self.num_heads // self.num_key_value_heads cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] From 03f15a1d06a6a0daa5b5432f3f3bb34d7a1b4825 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Fri, 27 Oct 2023 14:49:27 +0800 Subject: [PATCH 09/16] adding triton 2.1.0 features for inference --- .../tensor_parallel/modeling/llama.py | 69 +++++---- colossalai/kernel/triton/context_attention.py | 137 +++++++++--------- 2 files changed, 115 insertions(+), 91 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 60530ab6f022..aeee9bb2e4af 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,5 +1,6 @@ from typing import List, Optional, Tuple import math +import copy import torch from transformers.modeling_outputs import BaseModelOutputWithPast @@ -13,7 +14,6 @@ try: from vllm import layernorm_ops, pos_encoding_ops - rms_norm = layernorm_ops.rms_norm rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox HAS_VLLM_KERNERL = True @@ -29,6 +29,7 @@ from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama2_context_attention_fwd, ) + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd HAS_LIGHTLLM_KERNEL = True @@ -62,6 +63,40 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed +def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1): + if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + else: + lightllm_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + else: + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + class LlamaInferenceForwards: """ @@ -314,29 +349,9 @@ def llama_flash_attn_kvcache_forward( infer_state.context_mem_index, infer_state.cache_manager, ) - attn_output = torch.empty_like(query_states) - - if self.num_key_value_groups == 1: - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - ) - else: - lightllm_llama2_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - ) + + llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -366,10 +381,12 @@ def llama_flash_attn_kvcache_forward( cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) - cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) - cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_k= copy.deepcopy(cache_k) + copy_cache_v = copy.deepcopy(cache_v) + copy_cache_k = copy_cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = copy_cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) - attn_output = flash_attn_with_kvcache(q = query_states, k_cache = cache_k, v_cache = cache_v, softmax_scale = 1/ math.sqrt(self.head_dim), causal = True) + attn_output = flash_attn_with_kvcache(q = query_states, k_cache = copy_cache_k, v_cache = copy_cache_v, softmax_scale = 1/ math.sqrt(self.head_dim), causal = True) else: attn_output = torch.empty_like(query_states) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 1b4f6e44b0f2..5ce6f2c21385 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -5,7 +5,6 @@ try: import triton import triton.language as tl - HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -155,39 +154,43 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al num_warps = 4 if Lk <= 64 else 8 tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - alibi, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - # manually setting this blcok num, we can use tuning config to futher speed-up - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) + + if triton.__version__ < "2.1.0": + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + alibi, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + return @torch.no_grad() @@ -207,36 +210,40 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - None, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) + if triton.__version__ < "2.1.0": + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + return \ No newline at end of file From 703eae8152d4650990f7787fec43eae07be2f1f2 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Fri, 27 Oct 2023 15:58:43 +0800 Subject: [PATCH 10/16] update bloom triton kernel --- .../tensor_parallel/modeling/bloom.py | 11 ++- .../tensor_parallel/modeling/llama.py | 81 ++++++++++--------- 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 27a26caabefa..4d5db4b7691a 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -19,6 +19,12 @@ from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd +try: + from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_bloom_context_attention_fwd + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + def generate_alibi(n_head, dtype=torch.float16): """ @@ -469,7 +475,10 @@ def bloom_attention_forward( # output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) + if HAS_LIGHTLLM_KERNEL: + lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) + else: + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) else: diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index aeee9bb2e4af..de01971878ec 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -97,6 +97,32 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_ infer_state.cache_manager.past_key_values_length, ) +def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" + if num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + infer_state.other_kv_index, + ) + class LlamaInferenceForwards: """ @@ -329,7 +355,6 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin - # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) @@ -374,45 +399,25 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager, ) - if self.num_key_value_groups == 1: - if HAS_FLASH_KERNEL: - heads_per_group = self.num_heads // self.num_key_value_heads - cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] - cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - - query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) - copy_cache_k= copy.deepcopy(cache_k) - copy_cache_v = copy.deepcopy(cache_v) - copy_cache_k = copy_cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) - copy_cache_v = copy_cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) - - attn_output = flash_attn_with_kvcache(q = query_states, k_cache = copy_cache_k, v_cache = copy_cache_v, softmax_scale = 1/ math.sqrt(self.head_dim), causal = True) + if HAS_FLASH_KERNEL: + heads_per_group = self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - else: - attn_output = torch.empty_like(query_states) - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - ) - else: + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + copy_cache_k= copy.deepcopy(cache_k) + copy_cache_v = copy.deepcopy(cache_v) + copy_cache_k = copy_cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = copy_cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache(q = query_states, + k_cache = copy_cache_k, + v_cache = copy_cache_v, + softmax_scale = 1/ math.sqrt(self.head_dim), + causal = True) + else: attn_output = torch.empty_like(query_states) - Llama2TokenAttentionForwards.token_attn( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - infer_state.other_kv_index, - ) + llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) attn_output = attn_output.view(bsz, q_len, self.hidden_size) From 1e7ec04637586e48808d35f2f29633afa36ca71b Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Fri, 27 Oct 2023 16:02:05 +0800 Subject: [PATCH 11/16] remove useless vllm kernels --- .../tensor_parallel/modeling/llama.py | 32 ---- .../tensor_parallel/policies/llama.py | 3 - .../test_infer_ops/cuda/test_vllm_rmsnorm.py | 60 ------- .../cuda/test_vllm_rotary_embedding.py | 153 ------------------ 4 files changed, 248 deletions(-) delete mode 100644 tests/test_infer_ops/cuda/test_vllm_rmsnorm.py delete mode 100644 tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index de01971878ec..ffa48e8c2024 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -12,19 +12,6 @@ from ._utils import copy_kv_to_mem_cache -try: - from vllm import layernorm_ops, pos_encoding_ops - rms_norm = layernorm_ops.rms_norm - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print( - "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" - ) - HAS_VLLM_KERNERL = False - try: from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama2_context_attention_fwd, @@ -426,22 +413,3 @@ def llama_flash_attn_kvcache_forward( # return past_key_value as None return attn_output, None, None - -def get_llama_vllm_rmsnorm_forward(): - if HAS_VLLM_KERNERL: - - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - - return out - - return _vllm_rmsnorm_forward - else: - return None diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 7e163efe0173..4763fc2852f6 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -105,9 +105,6 @@ def module_policy(self): infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() - else: - # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 - infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: method_replacement = {"forward": partial(infer_forward)} diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py deleted file mode 100644 index a4d893f8e830..000000000000 --- a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -import pytest -import torch -from torch import nn - -try: - from vllm import layernorm_ops - - rms_norm = layernorm_ops.rms_norm - HAS_VLLM_KERNERL = True -except: - print("please install vllm kernels to install rmsnorm") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - HAS_VLLM_KERNERL = False - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - weight, - variance_epsilon, - ) - return out - - -@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") -def test_rmsnorm(): - data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") - hg_rms = LlamaRMSNorm(64) - hg_rms = hg_rms.half().cuda() - out_torch = hg_rms(data) - out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) - - check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) - assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" - - -if __name__ == "__main__": - test_rmsnorm() diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py deleted file mode 100644 index 40451ef6636d..000000000000 --- a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -from typing import Tuple - -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half - -try: - from vllm import pos_encoding_ops - - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - HAS_VLLM_KERNERL = False - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class RefRotaryEmbeddingNeox(nn.Module): - """Reference implementation of the GPT-NeoX style rotary embedding.""" - - def __init__( - self, - dim: int, - max_position_embeddings: int = 2048, - base: int = 10000, - ) -> None: - super().__init__() - self.rotary_dim = dim - self.max_position_embeddings = max_position_embeddings - - # Create cos and sin embeddings. - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) - t = torch.arange(max_position_embeddings).float() - freqs = torch.einsum("i,j->ij", t, inv_freq.float()) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=inv_freq.dtype) - sin = emb.sin().to(dtype=inv_freq.dtype) - self.register_buffer("cos_cached", cos, persistent=False) - self.register_buffer("sin_cached", sin, persistent=False) - - def forward( - self, - positions: torch.Tensor, # [num_tokens] - query: torch.Tensor, # [num_tokens, num_heads, head_size] - key: torch.Tensor, # [num_tokens, num_heads, head_size] - ) -> Tuple[torch.Tensor, torch.Tensor]: - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - - query_rot = query_rot.transpose(0, 1) - key_rot = key_rot.transpose(0, 1) - cos = F.embedding(positions, self.cos_cached) - sin = F.embedding(positions, self.sin_cached) - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - query_rot = query_rot.transpose(0, 1).contiguous() - key_rot = key_rot.transpose(0, 1).contiguous() - - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - - # Output query/key shape: [num_tokens, num_tokens, head_size] - return query, key - - -def run_rotary_embedding_neox( - num_tokens: int, - num_heads: int, - head_size: int, - max_position: int, - rotary_dim: int, - dtype: torch.dtype, - base: int = 10000, -) -> None: - positions = torch.randint(0, max_position, (num_tokens,), device="cuda") - query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") - key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") - - # Create the rotary embedding. - inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) - t = torch.arange(max_position).float() - freqs = torch.einsum("i,j -> ij", t, inv_freq.float()) - cos = freqs.cos() - sin = freqs.sin() - cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") - - # Run the kernel. The kernel is in-place, so we need to clone the inputs. - out_query = query.clone() - out_key = key.clone() - rotary_embedding_neox( - positions, - out_query, - out_key, - head_size, - cos_sin_cache, - ) - - # Run the reference implementation. - ref_rotary_embedding = RefRotaryEmbeddingNeox( - dim=rotary_dim, - max_position_embeddings=max_position, - base=base, - ).to(dtype=dtype, device="cuda") - ref_query, ref_key = ref_rotary_embedding( - positions, - query.view(num_tokens, num_heads, head_size), - key.view(num_tokens, num_heads, head_size), - ) - ref_query = ref_query.view(num_tokens, num_heads * head_size) - ref_key = ref_key.view(num_tokens, num_heads * head_size) - - # Compare the results. - assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) - assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) - - -@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") -def test_rotary_embedding(): - run_rotary_embedding_neox( - num_tokens=1024, - num_heads=8, - head_size=64, - max_position=8192, - rotary_dim=64, - dtype=torch.float16, - ) - - -if __name__ == "__main__": - test_rotary_embedding() From 4f8b52a77534218f6f8e2bf7e4bb1dca9dd89552 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Fri, 27 Oct 2023 16:07:31 +0800 Subject: [PATCH 12/16] clean codes --- colossalai/inference/tensor_parallel/policies/llama.py | 2 +- tests/test_infer/test_bloom_infer.py | 8 +++++++- tests/test_infer/test_chatglm2_infer.py | 8 +++++++- tests/test_infer/test_llama2_infer.py | 8 +++++++- tests/test_infer/test_llama_infer.py | 8 +++++++- 5 files changed, 29 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 4763fc2852f6..d6c072c747b7 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -9,7 +9,7 @@ from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy from ..modeling._utils import init_to_get_rotary -from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward +from ..modeling.llama import LlamaInferenceForwards try: from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index ba978ad9bf0d..d4366758d6a3 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -10,6 +10,12 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + TP_SIZE = 2 MAX_BATCH_SIZE = 4 MAX_INPUT_LEN = 16 @@ -52,7 +58,7 @@ def check_bloom(rank, world_size, port): run() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 399b70e1460e..02571fcfc2af 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -12,6 +12,12 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 1 BATCH_SIZE = 8 @@ -62,7 +68,7 @@ def check_chatglm2(rank, world_size, port): run_chatglm2_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py index 0eebed8892ea..13e7a61826ab 100644 --- a/tests/test_infer/test_llama2_infer.py +++ b/tests/test_infer/test_llama2_infer.py @@ -12,6 +12,12 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 @@ -57,7 +63,7 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index b424525a3719..a4f54d197065 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -12,6 +12,12 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 @@ -55,7 +61,7 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From d6b142b0c61bb7ffbce93771fade5ab92c878de3 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Sun, 29 Oct 2023 16:04:02 +0800 Subject: [PATCH 13/16] fix --- colossalai/inference/tensor_parallel/engine.py | 1 + .../inference/tensor_parallel/modeling/llama.py | 13 +++++++------ examples/inference/bench_llama.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 216b134f5fab..42db8217a229 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -309,6 +309,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to("cuda") diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ffa48e8c2024..6c06166f1aac 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -385,8 +385,11 @@ def llama_flash_attn_kvcache_forward( infer_state.decode_mem_index, infer_state.cache_manager, ) - - if HAS_FLASH_KERNEL: + + if HAS_LIGHTLLM_KERNEL: + attn_output = torch.empty_like(query_states) + llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) + else: heads_per_group = self.num_heads // self.num_key_value_heads cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] @@ -401,10 +404,8 @@ def llama_flash_attn_kvcache_forward( k_cache = copy_cache_k, v_cache = copy_cache_v, softmax_scale = 1/ math.sqrt(self.head_dim), - causal = True) - else: - attn_output = torch.empty_like(query_states) - llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) + causal = True) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 077dd1634fca..4523726719f0 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -121,8 +121,8 @@ def test_llama(args): parser = argparse.ArgumentParser() parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=2, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=128, help="Maximum input length") + parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") parser.add_argument( "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] From 0b720ef3c22ba100d95d2cf3e9eab87ec45ef47c Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 30 Oct 2023 11:28:30 +0800 Subject: [PATCH 14/16] adding files --- colossalai/inference/tensor_parallel/modeling/llama.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 6c06166f1aac..d18b9688e657 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -395,10 +395,8 @@ def llama_flash_attn_kvcache_forward( cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) - copy_cache_k= copy.deepcopy(cache_k) - copy_cache_v = copy.deepcopy(cache_v) - copy_cache_k = copy_cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) - copy_cache_v = copy_cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) attn_output = flash_attn_with_kvcache(q = query_states, k_cache = copy_cache_k, From af7d17f2f545af503f64d0d7d4423987468d1160 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 30 Oct 2023 11:49:48 +0800 Subject: [PATCH 15/16] fix readme --- colossalai/inference/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index f1a3104bc54e..4aca7aeb0582 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -34,7 +34,7 @@ In this section we discuss how the colossal inference works and integrates with - [x] policy - [x] context forward - [x] token forward - - [] support flash-decoding + - [x] support flash-decoding - [ ] Replace the kernels with `faster-transformer` in token-forward stage - [ ] Support all models - [x] Llama From 39db326162214a955d9863cbf4c79773f7380043 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Mon, 30 Oct 2023 12:49:18 +0800 Subject: [PATCH 16/16] update llama flash-decoding --- .../inference/tensor_parallel/modeling/llama.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 951e632c9469..8573bb965ea6 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -60,7 +60,8 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_ attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: lightllm_context_attention_fwd( @@ -70,7 +71,8 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_ attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" @@ -81,7 +83,8 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_ attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): @@ -95,7 +98,8 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: Llama2TokenAttentionForwards.token_attn( @@ -106,7 +110,8 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, infer_state.other_kv_index, ) @@ -377,6 +382,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager, ) + HAS_LIGHTLLM_KERNEL = False if HAS_LIGHTLLM_KERNEL: attn_output = torch.empty_like(query_states) llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups)