141 changes: 73 additions & 68 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -1,24 +1,34 @@
from typing import List, Optional, Tuple

import torch
import numpy as np
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaModel, LlamaDecoderLayer, LlamaAttention
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, apply_rotary_pos_emb

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
from typing import List, Optional, Tuple
from transformers.modeling_outputs import BaseModelOutputWithPast

try:
from vllm import pos_encoding_ops
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except ImportError:
print("falling back to the original huggingface rotary_embedding_neox")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
print(
"if vllm fails to install, please use this branch instead: https://github.com/tiandiao123/vllm/tree/setup_branch"
)
HAS_VLLM_KERNERL = False
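
For reference, a minimal PyTorch sketch of the rotary position embedding that the fused vllm kernel accelerates; it mirrors huggingface's apply_rotary_pos_emb fallback and is illustrative only, not the kernel's implementation:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # split the head dimension into halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_ref(q, k, cos, sin, position_ids):
    # q, k: [bs, num_heads, seq_len, head_dim]; cos/sin: [max_pos, head_dim]
    cos = cos[position_ids].unsqueeze(1)  # gather per-token angles -> [bs, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin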


class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention used by LlamaForCausalLM.
"""

@staticmethod
def llama_model_forward(
self: LlamaModel,
@@ -32,8 +42,8 @@ def llama_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch_size = input_ids.shape[0] # input_ids.shape[0]

batch_size = input_ids.shape[0] # input_ids.shape[0]

# infer_state = BatchInferState(batch_size, input_ids.shape[1])
# infer_state.batch_size = batch_size
@@ -45,13 +55,13 @@ def llama_model_forward(

infer_state = self.infer_state
b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i])
for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
position_ids = torch.from_numpy(
np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda()

# this is equivalent to computing cos/sin via rotary_emb; here we gather rows from the cached tables instead
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
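# Worked example of the flattened position layout above (illustrative):
# for a batch with seq_len = [3, 2],
#   np.concatenate([np.arange(0, 3), np.arange(0, 2)])  ->  [0, 1, 2, 0, 1]
# i.e. each sequence restarts at position 0, and position_cos/position_sin
# hold one (cos, sin) row per token in this flattened order.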

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
@@ -72,15 +82,16 @@ def llama_model_forward(
past_key_values_length = infer_state.cache_manager.past_key_values_length
# past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length

# FIXME: differentiate between the prefill and decode stages
# block_loc requires a different value-assigning method for the two stages
if use_cache and seq_length != 1:
# NOTE: assume prefill stage
# allocate memory block
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index)
infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
infer_state.context_mem_index)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
@@ -92,7 +103,9 @@ def llama_model_forward(
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
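# The decode branch above first tries alloc_contiguous so each sequence gets its
# new cache slot inside one contiguous run, letting the later KV copy be a plain
# slice write; on fragmentation it falls back to scattered slots addressed
# through block_loc. A rough sketch of the fast path (manager internals
# simplified for illustration only):
#   start, end = infer_state.decode_mem_start, infer_state.decode_mem_end
#   key_buffer[start:end] = new_keys      # contiguous: one slice copy
#   value_buffer[start:end] = new_values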
@@ -102,9 +115,10 @@ def llama_model_forward(

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
@@ -114,13 +128,12 @@ def llama_model_forward(

# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)

attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
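# _prepare_decoder_attention_mask expands the [bs, past + current] boolean mask
# built above into the additive causal mask the attention layers consume
# (standard huggingface LlamaModel behavior).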

hidden_states = inputs_embeds

@@ -145,7 +158,7 @@ def llama_model_forward(
)
infer_state.decode_layer_id += 1
hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

@@ -159,14 +172,14 @@ def llama_model_forward(

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

@staticmethod
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
@@ -212,7 +225,6 @@ def llama_decoder_layer_forward(

return outputs


@staticmethod
def llama_flash_attn_kvcache_forward(
self: LlamaAttention,
@@ -228,7 +240,7 @@ def llama_flash_attn_kvcache_forward(
assert use_cache is True, "use_cache should be set to True when using this llama attention"

bsz, q_len, _ = hidden_states.size()

# TODO might think about better way to handle transposed k and v
# key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
@@ -237,16 +249,16 @@ def llama_flash_attn_kvcache_forward(
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states_transposed = key_states.transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)

# cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len)
cos, sin = infer_state.position_cos, infer_state.position_sin
cos_sin_cache = torch.cat((cos, sin), dim=-1)

from vllm.pos_encoding_ops import rotary_embedding_neox

rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
cos, sin = infer_state.position_cos, infer_state.position_sin

if HAS_VLLM_KERNERL:
cos_sin_cache = torch.cat((cos, sin), dim=-1)
rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)
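# note: the fused vllm kernel rotates query/key in place, while the huggingface
# fallback returns fresh tensors, hence the reassignment in the else branch.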

def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
num_heads = key_buffer.shape[2]
head_dim = key_buffer.shape[3]
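# key_buffer/value_buffer here are the current step's projected states,
# [bs, seq_len, num_heads, head_dim] (shapes inferred from the call sites in
# this forward); shape[2]/shape[3] recover num_heads and head_dim so the states
# can be flattened into token-major (-1, num_heads, head_dim) rows before the
# triton copy scatters them into the per-layer cache at the given indices.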
@@ -258,9 +270,11 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,

# copy key and value calculated in current step to memory manager
if infer_state.is_context_stage:
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, infer_state.cache_manager)
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
infer_state.cache_manager)
else:
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, infer_state.cache_manager)
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index,
infer_state.cache_manager)

# this is worse than destcopy
# torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states)
@@ -269,19 +283,19 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
# FIXME might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len

query_states = query_states.transpose(1, 2)

if infer_state.is_context_stage:
# first token generation

# attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states,
# attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states,
# key_states,
# value_states,
# value_states,
# 0,
# 1/math.sqrt(self.head_dim),
# 1/math.sqrt(self.head_dim),
# causal,
# False)

@@ -290,33 +304,24 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
# calcu_shape for context_attention_fwd
calcu_shape1 = (-1, self.num_heads, self.head_dim)

llama_context_attn_fwd(query_states.view(calcu_shape1),
key_states.view(calcu_shape1),
value_states.view(calcu_shape1),
attn_output.view(calcu_shape1),
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length)
llama_context_attn_fwd(query_states.view(calcu_shape1), key_states.view(calcu_shape1),
value_states.view(calcu_shape1), attn_output.view(calcu_shape1),
infer_state.start_loc, infer_state.seq_len,
infer_state.cache_manager.past_key_values_length)
else:
# second and subsequent tokens
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)

token_attention_fwd(query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,

token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output,
infer_state.block_loc, infer_state.start_loc, infer_state.seq_len,
infer_state.cache_manager.past_key_values_length)

attn_output = attn_output.view(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

# return past_key_value as None
# return past_key_value as None
return attn_output, None, None
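
Prefill runs llama_context_attn_fwd over the full prompt, while decode runs token_attention_fwd with a single query row per sequence against the cached K/V. A dense PyTorch sketch of the decode-stage math the triton kernel fuses (illustrative only; the real kernel resolves cache rows through block_loc instead of taking dense tensors):

import math

import torch

def naive_token_attention(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor) -> torch.Tensor:
    # q: [bs, num_heads, head_dim]; k_cache, v_cache: [bs, past_len, num_heads, head_dim]
    scores = torch.einsum("bhd,bshd->bhs", q, k_cache) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores, dim=-1)                  # attend over all cached positions
    return torch.einsum("bhs,bshd->bhd", probs, v_cache)   # [bs, num_heads, head_dim]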


52 changes: 26 additions & 26 deletions colossalai/shardformer/modeling/llama.py
@@ -7,11 +7,33 @@
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
)
from transformers.utils import logging

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
from colossalai.pipeline.stage_manager import PipelineStageManager

try:
from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except ImportError:
print("falling back to the original huggingface rotary_embedding_neox")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
print(
"if vllm fails to install, please use this branch instead: https://github.com/tiandiao123/vllm/tree/setup_branch"
)
HAS_VLLM_KERNERL = False
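
For reference, the normalization the fused rms_norm kernel computes, in plain PyTorch; a sketch matching LlamaRMSNorm's definition, not the kernel itself:

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # scale by the reciprocal root-mean-square over the hidden dimension
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)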


class LlamaPipelineForwards:
'''
@@ -393,20 +415,6 @@ def llama_for_sequence_classification_forward(

def get_llama_flash_attention_forward():

from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

try:
from vllm import pos_encoding_ops
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except:
print("fall back to original rotary_embedding_neox of huggingface")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch")
HAS_VLLM_KERNERL = False


def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
@@ -428,7 +436,7 @@ def forward(
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

if HAS_VLLM_KERNERL:
cos_sin_cache = torch.cat((cos, sin), dim=-1)
rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache)
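# cos_sin_cache packs the cos and sin tables side by side along the last dim so
# the fused kernel can fetch both angles with one position lookup (the exact
# layout rotary_embedding_neox expects is assumed here; see vllm for details).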
@@ -471,17 +479,9 @@ def forward(


def get_llama_vllm_rmsnorm_forward():
try:
from vllm import layernorm_ops
rms_norm = layernorm_ops.rms_norm
HAS_VLLM_KERNERL = True
except:
print("please install vllm kernels to install rmsnorm")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch")
HAS_VLLM_KERNERL = False


if HAS_VLLM_KERNERL:

def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)