10 changes: 5 additions & 5 deletions colossalai/inference/tensor_parallel/engine.py
@@ -181,10 +181,11 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
max_len_in_batch = -1
if isinstance(inputs, (BatchEncoding, dict)):
for i, attn_mask in enumerate(attention_mask):
if isinstance(attn_mask, torch.Tensor):
curr_seq_len = int(torch.sum(attn_mask))
else:
curr_seq_len = int(sum(attn_mask))
curr_seq_len = len(attn_mask)
# if isinstance(attn_mask, torch.Tensor):
# curr_seq_len = int(torch.sum(attn_mask))
# else:
# curr_seq_len = int(sum(attn_mask))
seq_lengths[i] = curr_seq_len
seq_start_indexes[i] = start_index
start_index += curr_seq_len
@@ -196,7 +197,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
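
For reference, a minimal sketch of the behaviour change in the length computation above: summing the attention mask counts only non-padded positions, while len(attn_mask) returns the full mask length, so the two agree only when the batch arrives unpadded (the tensors below are illustrative only):

import torch

# Previous behaviour: count the unmasked positions.
padded_mask = torch.tensor([1, 1, 1, 0, 0])
old_len = int(torch.sum(padded_mask))    # 3

# New behaviour: take the full length of the mask.
new_len = len(padded_mask)               # 5

# The two only agree when no padding is present.
unpadded_mask = torch.tensor([1, 1, 1])
assert int(torch.sum(unpadded_mask)) == len(unpadded_mask)  # both 3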
99 changes: 38 additions & 61 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -3,16 +3,12 @@
import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaRMSNorm,
)
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

try:
@@ -28,24 +24,26 @@
)
HAS_VLLM_KERNERL = False


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
@@ -75,17 +73,6 @@ def llama_model_forward(
batch_size = input_ids.shape[0] # input_ids.shape[0]

infer_state = self.infer_state
b_seq_len_numpy = infer_state.seq_len.cpu().numpy()

if HAS_VLLM_KERNERL:
position_ids = torch.from_numpy(
np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda()

# this equals
infer_state.position_cos = torch.index_select(self._cos_cached, 0,
position_ids).view(position_ids.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0,
position_ids).view(position_ids.shape[0], -1)

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -138,7 +125,6 @@ def llama_model_forward(
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_key_values_length,
@@ -149,6 +135,17 @@
else:
position_ids = position_ids.view(-1, seq_length).long()

if infer_state.is_context_stage:

infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

@@ -271,50 +268,30 @@ def llama_flash_attn_kvcache_forward(
# key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states_transposed = key_states.transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)


# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len

if HAS_VLLM_KERNERL:
# NOTE: fix rotatry embedding precision problem
cos, sin = infer_state.position_cos, infer_state.position_sin

# value_states_transposed = value_states.transpose(1, 2)

# cos, sin = self.rotary_emb(value_states_transposed,
# seq_len=infer_state.cache_manager.past_key_values_length)

cos_sin_cache = torch.cat((cos, sin), dim=-1)

key_states = key_states.view(-1, self.num_heads * self.head_dim)
query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads * self.head_dim)
rotary_embedding_neox(position_ids.squeeze(1), query_states, key_states, self.head_dim, cos_sin_cache)


query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)

else:
# NOTE: there are some issues for original rotary_embedding_neox of huggingface

value_states_transposed = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states_transposed,
seq_len=infer_state.cache_manager.past_key_values_length)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)

query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )

rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)

def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return

query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)

if infer_state.is_context_stage:
# first token generation
@@ -364,6 +341,7 @@ def llama_flash_attn_kvcache_forward(
def get_llama_vllm_rmsnorm_forward():

if HAS_VLLM_KERNERL:

def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
@@ -379,4 +357,3 @@ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return _vllm_rmsnorm_forward
else:
return None
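
As a rough reference for the new rotary-table selection above, a minimal sketch of the two stages, assuming _cos_cached / _sin_cached behave like 2-D tables indexed by absolute position (the shapes and sample values below are illustrative only, not the exact module attributes):

import torch

max_positions, rotary_dim = 2048, 64
cos_cached = torch.randn(max_positions, rotary_dim)
sin_cached = torch.randn(max_positions, rotary_dim)

# Context (prefill) stage: gather one row per prompt token by its position id.
position_ids = torch.tensor([[0, 1, 2, 3]])              # [bs, seq_len]
flat_ids = position_ids.view(-1)
position_cos = torch.index_select(cos_cached, 0, flat_ids).view(flat_ids.shape[0], -1)
position_sin = torch.index_select(sin_cached, 0, flat_ids).view(flat_ids.shape[0], -1)

# Decode stage: gather one row per sequence at position (current length - 1).
seq_len = torch.tensor([4, 7])                           # per-sequence lengths
position_cos = torch.index_select(cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
position_sin = torch.index_select(sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)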

93 changes: 93 additions & 0 deletions colossalai/kernel/triton/rotary_embedding_kernel.py
@@ -0,0 +1,93 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl


@triton.jit
def _rotary_kernel(
q,
Cos,
Sin,
q_bs_stride,
q_h_stride,
q_d_stride,
cos_bs_stride,
cos_d_stride,
total_len,
HEAD_NUM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
current_head_index = tl.program_id(0)
current_seq_index = tl.program_id(1)

current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)

dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride

off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

q0 = tl.load(q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0)
q1 = tl.load(q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0)

cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)

out0 = q0 * cos - q1 * sin
out1 = q0 * sin + q1 * cos

tl.store(q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
tl.store(q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))

return


@torch.no_grad()
def rotary_embedding_fwd(q, cos, sin):
total_len = q.shape[0]
head_num = q.shape[1]
head_dim = q.shape[2]
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
BLOCK_HEAD = 4
BLOCK_SEQ = 32
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
if head_dim >= 128:
num_warps = 8
else:
num_warps = 4

_rotary_kernel[grid](
q,
cos,
sin,
q.stride(0),
q.stride(1),
q.stride(2),
cos.stride(0),
cos.stride(1),
total_len,
HEAD_NUM=head_num,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SEQ=BLOCK_SEQ,
HEAD_DIM=head_dim,
num_warps=num_warps,
num_stages=1,
)
return
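
For completeness, a minimal usage sketch of the new Triton entry point, assuming a CUDA device and that cos / sin provide one row of head_dim // 2 values per token (the shapes, dtype, and frequency base below are illustrative assumptions, not requirements stated in the kernel):

import torch

from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd

total_tokens, num_heads, head_dim = 16, 8, 64
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)

# Build per-token cos/sin tables for positions 0..total_tokens-1.
positions = torch.arange(total_tokens, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda", dtype=torch.float32) / head_dim))
angles = positions[:, None] * inv_freq[None, :]          # [total_tokens, head_dim // 2]
cos, sin = angles.cos().half(), angles.sin().half()

rotary_embedding_fwd(q, cos, sin)   # rotates q in place, pairing dim i with dim i + head_dim // 2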