24 changes: 0 additions & 24 deletions colossalai/inference/build.sh

This file was deleted.

38 changes: 25 additions & 13 deletions colossalai/inference/engine/modeling/llama.py
@@ -27,9 +27,15 @@
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False

try:
    from colossalai.kernel.triton.flash_decoding import token_flash_decoding
    HAS_TRITON_FLASH_DECODING_KERNEL = True
except:
    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
    HAS_TRITON_FLASH_DECODING_KERNEL = False

try:
    from flash_attn import flash_attn_with_kvcache

    HAS_FLASH_KERNEL = True
except:
    HAS_FLASH_KERNEL = False
@@ -42,7 +48,6 @@ def rotate_half(x):
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -67,7 +72,6 @@ def llama_triton_context_attention(
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
@@ -78,7 +82,6 @@ def llama_triton_context_attention(
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
@@ -90,13 +93,20 @@ def llama_triton_context_attention(
attn_output,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)


def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1):
    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
        token_flash_decoding(
            q=query_states,
            o_tensor=attn_output,
            infer_state=infer_state,
            q_head_num=q_head_num,
            head_dim=head_dim,
            cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
            cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
        )
        return

if num_key_value_groups == 1:
token_attention_fwd(
query_states,
@@ -106,7 +116,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
@@ -118,7 +127,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
# infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
infer_state.other_kv_index,
)
@@ -451,10 +459,14 @@ def llama_flash_attn_kvcache_forward(
)

if HAS_LIGHTLLM_KERNEL:

attn_output = torch.empty_like(query_states)
llama_triton_token_attention(
query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
)
llama_triton_token_attention(
    query_states=query_states,
    attn_output=attn_output,
    infer_state=infer_state,
    num_key_value_groups=self.num_key_value_groups,
    q_head_num=q_len * self.num_heads,
    head_dim=self.head_dim,
)
else:
self.num_heads // self.num_key_value_heads
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
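For readers wiring this change into their own forward pass, here is a minimal caller-side sketch of the widened llama_triton_token_attention signature. The helper name and the assumption that q_len is 1 during single-token decode are illustrative, not part of the PR: leaving q_head_num and head_dim at their -1 sentinels keeps the old token_attention_fwd path, while supplying real head geometry lets the call route to the Triton flash-decoding kernel when it imported successfully.

import torch

def decode_token_attention(self, query_states, infer_state, q_len=1):
    # Hypothetical wrapper mirroring the call site in llama_flash_attn_kvcache_forward.
    attn_output = torch.empty_like(query_states)
    llama_triton_token_attention(
        query_states=query_states,
        attn_output=attn_output,
        infer_state=infer_state,
        num_key_value_groups=self.num_key_value_groups,
        q_head_num=q_len * self.num_heads,  # q_len is 1 for single-token decode
        head_dim=self.head_dim,
    )
    return attn_output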
1 change: 1 addition & 0 deletions colossalai/kernel/triton/context_attention.py
@@ -137,6 +137,7 @@ def _context_flash_attention_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
return
else:
# this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
@triton.jit
def _context_flash_attention_kernel_2(
Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
50 changes: 50 additions & 0 deletions colossalai/kernel/triton/flash_decoding.py
@@ -0,0 +1,50 @@
# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
import torch
try:
    from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
    HAS_LIGHTLLM_KERNEL = True
except:
    print("install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
    HAS_LIGHTLLM_KERNEL = False


if HAS_LIGHTLLM_KERNEL:

    def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
        BLOCK_SEQ = 256
        batch_size = infer_state.batch_size
        max_len_in_batch = infer_state.max_len_in_batch

        calcu_shape1 = (batch_size, q_head_num, head_dim)

        # Lazily allocate the intermediate buffers that hold per-chunk partial
        # outputs and their log-sum-exp values.
        if getattr(infer_state, "mid_o", None) is None:
            infer_state.mid_o = torch.empty(
                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim],
                dtype=torch.float32,
                device="cuda",
            )
            infer_state.mid_o_logexpsum = torch.empty(
                [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1],
                dtype=torch.float32,
                device="cuda",
            )

        mid_o = infer_state.mid_o
        mid_o_logexpsum = infer_state.mid_o_logexpsum

        # Stage 1: compute partial attention results per BLOCK_SEQ chunk of the KV cache.
        flash_decode_stage1(
            q.view(calcu_shape1),
            cache_k,
            cache_v,
            infer_state.block_loc,
            infer_state.seq_len,
            infer_state.max_len_in_batch,
            mid_o,
            mid_o_logexpsum,
            BLOCK_SEQ,
        )
        # Stage 2: merge the per-chunk partials into the final output tensor.
        flash_decode_stage2(
            mid_o,
            mid_o_logexpsum,
            infer_state.seq_len,
            o_tensor.view(calcu_shape1),
            BLOCK_SEQ,
        )
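For context on what the two lightllm kernels compute: flash decoding splits each sequence's KV cache into chunks of BLOCK_SEQ tokens, stage 1 produces a normalized partial attention output and a log-sum-exp per chunk (the mid_o and mid_o_logexpsum buffers above), and stage 2 merges those partials into the exact softmax-attention result. The sketch below is an illustrative pure-PyTorch reference of that reduction for a single head with a contiguous KV cache; the function name and shapes are assumptions for exposition, not the kernels' actual interface, which reads the paged cache through block_loc.

import math
import torch

def flash_decoding_reference(q, k, v, block_seq=256):
    # q: [head_dim]; k, v: [seq_len, head_dim] for one attention head.
    scale = 1.0 / math.sqrt(q.shape[-1])
    partial_out, partial_lse = [], []
    # Stage 1: each KV chunk yields a normalized partial output and the
    # log-sum-exp of its attention scores.
    for start in range(0, k.shape[0], block_seq):
        k_c, v_c = k[start:start + block_seq], v[start:start + block_seq]
        scores = (k_c @ q) * scale
        m = scores.max()
        e = torch.exp(scores - m)
        z = e.sum()
        partial_out.append((e @ v_c) / z)
        partial_lse.append(m + torch.log(z))
    # Stage 2: numerically stable merge of the per-chunk partials.
    lse = torch.stack(partial_lse)
    w = torch.exp(lse - lse.max())
    return (torch.stack(partial_out) * w[:, None]).sum(dim=0) / w.sum()

# Sanity check against plain softmax attention over the full sequence.
q, k, v = torch.randn(64), torch.randn(1000, 64), torch.randn(1000, 64)
expected = torch.softmax((k @ q) / math.sqrt(64), dim=0) @ v
assert torch.allclose(flash_decoding_reference(q, k, v), expected, atol=1e-4)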
10 changes: 5 additions & 5 deletions examples/inference/hybrid_llama.py
@@ -75,11 +75,11 @@ def run_tp_pipeline_inference(rank, world_size, port, args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
parser.add_argument("-tp", "--tp_size", type=int, default=2, help="Tensor parallel size")
parser.add_argument("-pp", "--pp_size", type=int, default=2, help="Tensor parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="Maximum batch size")
parser.add_argument("--max_input_len", type=int, default=32, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=16, help="Maximum output length")
parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("-pp", "--pp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="Maximum batch size")
parser.add_argument("--max_input_len", type=int, default=512, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=256, help="Maximum output length")
parser.add_argument("--micro_batch_size", type=int, default=2, help="Micro batch size")

args = parser.parse_args()
2 changes: 1 addition & 1 deletion requirements/requirements-infer.txt
@@ -2,6 +2,6 @@ transformers==4.34.0
packaging
ninja
auto-gptq==0.5.0
git+https://github.com/ModelTC/lightllm.git@28c1267cfca536b7b4f28e921e03de735b003039
git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9