From 39183dfda210b3e53f5f65fc77a3acb555a94dfd Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Mon, 20 Nov 2023 00:30:35 +0800
Subject: [PATCH 1/5] added Triton flash-decoding based on the lightllm kernel

---
 colossalai/inference/build.sh                 |  2 +-
 colossalai/inference/engine/modeling/llama.py | 32 +++++++++---
 colossalai/kernel/triton/context_attention.py |  1 +
 colossalai/kernel/triton/flash_decoding.py    | 50 +++++++++++++++++++
 4 files changed, 76 insertions(+), 9 deletions(-)
 create mode 100644 colossalai/kernel/triton/flash_decoding.py

diff --git a/colossalai/inference/build.sh b/colossalai/inference/build.sh
index 6a73f6f0b985..88a88d692dda 100644
--- a/colossalai/inference/build.sh
+++ b/colossalai/inference/build.sh
@@ -9,7 +9,7 @@
 mkdir 3rdParty
 cd 3rdParty
 git clone https://github.com/ModelTC/lightllm
 cd lightllm
-git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+git checkout ece7b43f8a6dfa74027adc77c2c176cff28c76c8
 pip install -e .
 cd ..

diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py
index 2dd1858d60f5..e12b09dfd380 100644
--- a/colossalai/inference/engine/modeling/llama.py
+++ b/colossalai/inference/engine/modeling/llama.py
@@ -27,9 +27,15 @@
     print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
     HAS_LIGHTLLM_KERNEL = False
 
+try:
+    from colossalai.kernel.triton.flash_decoding import token_flash_decoding
+    HAS_TRITON_FLASH_DECODING_KERNEL = True
+except:
+    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_TRITON_FLASH_DECODING_KERNEL = False
+
 try:
     from flash_attn import flash_attn_with_kvcache
-
     HAS_FLASH_KERNEL = True
 except:
     HAS_FLASH_KERNEL = False
@@ -95,8 +101,17 @@ def llama_triton_context_attention(
         )
 
 
-def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
-    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
+    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
+        token_flash_decoding(q = query_states,
+                             o_tensor = attn_output,
+                             infer_state = infer_state,
+                             q_head_num = q_head_num,
+                             head_dim = head_dim,
+                             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+                             cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
+        return
+
     if num_key_value_groups == 1:
         token_attention_fwd(
             query_states,
@@ -106,7 +121,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
     else:
@@ -118,7 +132,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
             infer_state.block_loc,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
             infer_state.other_kv_index,
         )
@@ -452,9 +465,12 @@ def llama_flash_attn_kvcache_forward(
         )
         if HAS_LIGHTLLM_KERNEL:
             attn_output = torch.empty_like(query_states)
-            llama_triton_token_attention(
-                query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
-            )
+            llama_triton_token_attention(query_states = query_states,
+                attn_output = attn_output,
+                infer_state = infer_state,
+                num_key_value_groups = self.num_key_value_groups,
+                q_head_num = q_len * self.num_heads,
+                head_dim = self.head_dim)
         else:
             self.num_heads // self.num_key_value_heads
             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]

diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py
index 1ad7a80eb5e7..3d9a23d2f5d2 100644
--- a/colossalai/kernel/triton/context_attention.py
+++ b/colossalai/kernel/triton/context_attention.py
@@ -137,6 +137,7 @@ def _context_flash_attention_kernel(
             tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
         return
 else:
+    # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
     @triton.jit
     def _context_flash_attention_kernel_2(
         Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,

diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
new file mode 100644
index 000000000000..dbc5d6cf217a
--- /dev/null
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -0,0 +1,50 @@
+import torch
+
+try:
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+    HAS_LIGHTLLM_KERNEL = False
+
+if HAS_LIGHTLLM_KERNEL:
+    # adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
+    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
+        BLOCK_SEQ = 256
+        batch_size = infer_state.batch_size
+        max_len_in_batch = infer_state.max_len_in_batch
+
+        calcu_shape1 = (batch_size, q_head_num, head_dim)
+
+        if getattr(infer_state, 'mid_o', None) is None:
+            infer_state.mid_o = torch.empty([batch_size,
+                                             q_head_num,
+                                             max_len_in_batch // BLOCK_SEQ + 1,
+                                             head_dim],
+                                            dtype=torch.float32,
+                                            device="cuda")
+            infer_state.mid_o_logexpsum = torch.empty([batch_size,
+                                                       q_head_num,
+                                                       max_len_in_batch // BLOCK_SEQ + 1],
+                                                      dtype=torch.float32,
+                                                      device="cuda")
+
+        mid_o = infer_state.mid_o
+        mid_o_logexpsum = infer_state.mid_o_logexpsum
+
+        flash_decode_stage1(q.view(calcu_shape1),
+                            cache_k,
+                            cache_v,
+                            infer_state.block_loc,
+                            infer_state.seq_len,
+                            infer_state.max_len_in_batch,
+                            mid_o,
+                            mid_o_logexpsum,
+                            BLOCK_SEQ)
+        flash_decode_stage2(mid_o,
+                            mid_o_logexpsum,
+                            infer_state.seq_len,
+                            o_tensor.view(calcu_shape1),
+                            BLOCK_SEQ)
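Note on the two-stage kernel introduced above: flash_decode_stage1 splits the KV cache into BLOCK_SEQ-sized chunks and writes, per chunk, a partial attention output (mid_o) together with the log-sum-exp of that chunk's scores (mid_o_logexpsum); flash_decode_stage2 then combines the chunks, weighting each partial output by a softmax over the per-chunk log-sum-exps, which recovers exact attention. Below is a minimal single-token PyTorch reference of the same math; the name flash_decode_reference and the q: [heads, dim], k/v: [seq, heads, dim] layout are illustrative assumptions, not the lightllm kernels' actual paged-cache layout.

    import math
    import torch

    def flash_decode_reference(q, k, v, block_seq=256):
        # q: [H, D] for one decode token; k, v: [S, H, D] cached keys/values
        H, D = q.shape
        S = k.shape[0]
        n_chunks = math.ceil(S / block_seq)
        mid_o = torch.empty(H, n_chunks, D)       # per-chunk partial outputs
        mid_logexpsum = torch.empty(H, n_chunks)  # per-chunk log-sum-exp

        # Stage 1: each chunk attends over its own keys only.
        for i in range(n_chunks):
            ks = k[i * block_seq : (i + 1) * block_seq]
            vs = v[i * block_seq : (i + 1) * block_seq]
            scores = torch.einsum("hd,shd->hs", q, ks) / math.sqrt(D)
            m = scores.max(dim=1, keepdim=True).values  # max-shift for numerical stability
            p = torch.exp(scores - m)
            mid_o[:, i] = torch.einsum("hs,shd->hd", p, vs) / p.sum(dim=1, keepdim=True)
            mid_logexpsum[:, i] = torch.log(p.sum(dim=1)) + m.squeeze(1)

        # Stage 2: merge chunks; softmax over the log-sum-exps yields exact attention.
        w = torch.softmax(mid_logexpsum, dim=1)  # [H, n_chunks]
        return torch.einsum("hc,hcd->hd", w, mid_o)

    # agreement check against single-pass attention
    q, k, v = torch.randn(4, 16), torch.randn(1000, 4, 16), torch.randn(1000, 4, 16)
    attn = torch.softmax(torch.einsum("hd,shd->hs", q, k) / math.sqrt(16), dim=1)
    assert torch.allclose(flash_decode_reference(q, k, v), torch.einsum("hs,shd->hd", attn, v), atol=1e-5)

This also explains the mid_o buffer shape allocated above: max_len_in_batch // BLOCK_SEQ + 1 is the worst-case chunk count per sequence.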
From b38ee81b91ba9da417d958f94f169d8f9ecd98e8 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Mon, 20 Nov 2023 10:58:44 +0800
Subject: [PATCH 2/5] add req

---
 colossalai/kernel/triton/flash_decoding.py | 4 ++--
 requirements/requirements-infer.txt        | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
index dbc5d6cf217a..9b7b27fa1f49 100644
--- a/colossalai/kernel/triton/flash_decoding.py
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -1,5 +1,5 @@
+# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
 import torch
-
 try:
     from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
     from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
@@ -8,8 +8,8 @@
     print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
     HAS_LIGHTLLM_KERNEL = False
 
+
 if HAS_LIGHTLLM_KERNEL:
-    # adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
     def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
         BLOCK_SEQ = 256
         batch_size = infer_state.batch_size

diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt
index 461dcb23b9fb..3151504df40e 100644
--- a/requirements/requirements-infer.txt
+++ b/requirements/requirements-infer.txt
@@ -2,6 +2,6 @@
 transformers==4.34.0
 packaging
 ninja
 auto-gptq==0.5.0
-git+https://github.com/ModelTC/lightllm.git@28c1267cfca536b7b4f28e921e03de735b003039
+git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
 git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
 git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
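The pin above is what the stage1/stage2 imports in flash_decoding.py are built against. A quick import probe (a standalone sketch, not part of the PR) can confirm that an environment installed from requirements-infer.txt actually provides the two kernels, mirroring the guarded imports in the module:

    try:
        from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
        from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
        print("lightllm flash-decoding kernels available")
    except ImportError as err:
        print(f"pinned lightllm revision missing or incomplete: {err}")

Note that the module itself uses a bare except rather than ImportError, so any failure inside lightllm's own imports (e.g. a missing Triton) also falls back to HAS_LIGHTLLM_KERNEL = False.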
From cd877af081a778c7d2e2686e6cd2d023de475a00 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Mon, 20 Nov 2023 11:20:36 +0800
Subject: [PATCH 3/5] clean

---
 colossalai/inference/engine/modeling/llama.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py
index e12b09dfd380..84327d95fbf8 100644
--- a/colossalai/inference/engine/modeling/llama.py
+++ b/colossalai/inference/engine/modeling/llama.py
@@ -48,7 +48,6 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
-
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -73,7 +72,6 @@ def llama_triton_context_attention(
             attn_output,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
     else:
@@ -84,7 +82,6 @@ def llama_triton_context_attention(
             attn_output,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
     else:
@@ -96,7 +93,6 @@ def llama_triton_context_attention(
             attn_output,
             infer_state.start_loc,
             infer_state.seq_len,
-            # infer_state.cache_manager.past_key_values_length,
             infer_state.max_len_in_batch,
         )
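For context on the rotate_half helper touched above: it performs the half-rotation used by rotary position embeddings. Below is a self-contained sketch of the standard application; apply_rotary is a simplified illustration, while the in-tree apply_rotary_pos_emb additionally gathers the cos/sin rows by position_ids, as the squeeze comment in the hunk suggests.

    import torch

    def rotate_half(x):
        # split the head dim in half and rotate: (x1, x2) -> (-x2, x1)
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(q, k, cos, sin):
        # q, k: [batch, heads, seq, head_dim]; cos, sin broadcast over batch/heads
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)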
From 78fd05c1d0930bba6b7fb9b555002803c28a8e06 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Mon, 20 Nov 2023 11:21:43 +0800
Subject: [PATCH 4/5] clean

---
 colossalai/inference/engine/modeling/llama.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py
index 84327d95fbf8..b7bc94d0eae0 100644
--- a/colossalai/inference/engine/modeling/llama.py
+++ b/colossalai/inference/engine/modeling/llama.py
@@ -96,7 +96,6 @@ def llama_triton_context_attention(
             infer_state.max_len_in_batch,
         )
 
-
 def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
     if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
         token_flash_decoding(q = query_states,
@@ -460,13 +459,14 @@ def llama_flash_attn_kvcache_forward(
         )
 
         if HAS_LIGHTLLM_KERNEL:
+
             attn_output = torch.empty_like(query_states)
             llama_triton_token_attention(query_states = query_states,
-                attn_output = attn_output,
-                infer_state = infer_state,
-                num_key_value_groups = self.num_key_value_groups,
-                q_head_num = q_len * self.num_heads,
-                head_dim = self.head_dim)
+                                         attn_output = attn_output,
+                                         infer_state = infer_state,
+                                         num_key_value_groups = self.num_key_value_groups,
+                                         q_head_num = q_len * self.num_heads,
+                                         head_dim = self.head_dim)
         else:
             self.num_heads // self.num_key_value_heads
             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
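A note on the call being tidied above: the new q_head_num argument is passed as q_len * self.num_heads, and this branch only runs at the decode stage where q_len is 1, so it reduces to the model's head count. The -1 defaults act as sentinels that keep older call sites on the existing kernel. A toy illustration of the resulting dispatch (token_attention_dispatch is a made-up name for illustration only):

    def token_attention_dispatch(has_flash_decoding, q_head_num=-1, head_dim=-1):
        # mirrors llama_triton_token_attention: take the two-stage flash-decoding
        # path only when the caller supplies real head geometry
        if has_flash_decoding and q_head_num != -1 and head_dim != -1:
            return "token_flash_decoding"
        return "token_attention_fwd"

    assert token_attention_dispatch(True) == "token_attention_fwd"
    assert token_attention_dispatch(True, q_head_num=32, head_dim=128) == "token_flash_decoding"
    assert token_attention_dispatch(False, q_head_num=32, head_dim=128) == "token_attention_fwd"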
From 70efa2dbb18777079aab9300415588dd487b5c23 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Mon, 20 Nov 2023 11:40:38 +0800
Subject: [PATCH 5/5] delete build.sh

---
 colossalai/inference/build.sh      | 24 ------------------------
 examples/inference/hybrid_llama.py | 10 +++++-----
 2 files changed, 5 insertions(+), 29 deletions(-)
 delete mode 100644 colossalai/inference/build.sh

diff --git a/colossalai/inference/build.sh b/colossalai/inference/build.sh
deleted file mode 100644
index 88a88d692dda..000000000000
--- a/colossalai/inference/build.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/usr/bin/env bash
-
-# install triton
-pip install triton
-pip install transformers
-
-# install lightllm and flash-attention
-mkdir 3rdParty
-cd 3rdParty
-git clone https://github.com/ModelTC/lightllm
-cd lightllm
-git checkout ece7b43f8a6dfa74027adc77c2c176cff28c76c8
-pip install -e .
-cd ..
-
-git clone -recursive https://github.com/Dao-AILab/flash-attention
-cd flash-attention
-pip install -e .
-
-cd ../../
-
-
-
-
diff --git a/examples/inference/hybrid_llama.py b/examples/inference/hybrid_llama.py
index bdfa4e5e8574..1bd34afefb79 100644
--- a/examples/inference/hybrid_llama.py
+++ b/examples/inference/hybrid_llama.py
@@ -75,11 +75,11 @@ def run_tp_pipeline_inference(rank, world_size, port, args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
-    parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
-    parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Tensor parallel size")
-    parser.add_argument("-b", "--batch_size", type=int, default=8, help="Maximum batch size")
-    parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
-    parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
+    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Pipeline parallel size")
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size")
+    parser.add_argument("--max_input_len", type=int, default=512, help="Maximum input length")
+    parser.add_argument("--max_output_len", type=int, default=256, help="Maximum output length")
     parser.add_argument("--micro_batch_size", type=int, default=2, help="Micro batch size")
     args = parser.parse_args()