50 changes: 28 additions & 22 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -3,15 +3,22 @@
import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
apply_rotary_pos_emb,
LlamaRMSNorm
)

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

try:
from vllm import pos_encoding_ops
from vllm import pos_encoding_ops, layernorm_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except:
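The `except` branch is collapsed in this hunk; as in the shardformer copy of this module further down, an import failure simply disables the fused path. A sketch of the complete optional-import pattern, with the fallback message taken from that copy (narrowing to `ImportError` is my addition — the PR keeps a bare `except`):

```python
try:
    from vllm import pos_encoding_ops, layernorm_ops
    rms_norm = layernorm_ops.rms_norm
    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
    HAS_VLLM_KERNERL = True  # call sites branch on this flag
except ImportError:
    print("fall back to original rotary_embedding_neox of huggingface")
    HAS_VLLM_KERNERL = False
```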
@@ -45,14 +52,6 @@ def llama_model_forward(

batch_size = input_ids.shape[0] # input_ids.shape[0]

# infer_state = BatchInferState(batch_size, input_ids.shape[1])
# infer_state.batch_size = batch_size
# # NOTE: dummy implementation here for testing, just assume all inputs same length
# infer_state.block_loc = self.block_loc
# infer_state.start_loc = self.start_loc
# infer_state.seq_len = self.seq_len
# infer_state.max_len_in_batch = self.max_len_in_batch

infer_state = self.infer_state
b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
position_ids = torch.from_numpy(
@@ -276,10 +275,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
_copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index,
infer_state.cache_manager)

# this is worse than destcopy
# torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states)
# torch.Tensor.copy_(infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],value_states)

# FIXME might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
@@ -291,14 +286,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
if infer_state.is_context_stage:
# first token generation

# attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states,
# key_states,
# value_states,
# 0,
# 1/math.sqrt(self.head_dim),
# causal,
# False)

attn_output = torch.empty_like(query_states)

# calcu_shape for context_attention_fwd
@@ -325,3 +312,22 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,

# return past_key_value as None
return attn_output, None, None

def get_llama_vllm_rmsnorm_forward():

if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)

return out

return _vllm_rmsnorm_forward
else:
return None
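A minimal usage sketch for the helper above (my illustration, not part of the PR; it assumes vllm is installed and a CUDA device is available, since the fused kernel targets fp16 GPU tensors):

```python
import types

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

from colossalai.inference.tensor_parallel.modeling.llama import get_llama_vllm_rmsnorm_forward

norm = LlamaRMSNorm(hidden_size=4096).half().cuda()

vllm_forward = get_llama_vllm_rmsnorm_forward()
if vllm_forward is not None:
    # bind the fused forward to this instance, as the policy's method replacement would
    norm.forward = types.MethodType(vllm_forward, norm)

x = torch.randn(2, 16, 4096, dtype=torch.float16, device="cuda")
out = norm(x)  # dispatches to vllm's rms_norm kernel when it was imported successfully
```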
11 changes: 10 additions & 1 deletion colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,8 +1,10 @@
from functools import partial
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling.llama import LlamaInferenceForwards
from ..modeling.llama import get_llama_vllm_rmsnorm_forward


class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
@@ -11,7 +13,6 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = super().module_policy()
self.shard_config._infer()

@@ -36,5 +37,13 @@ def module_policy(self):
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaAttention)

# TODO: adding rms_norm caused precision issue, fix @tiandiao123
# infer_forward = get_llama_vllm_rmsnorm_forward()
# if infer_forward is not None:
# method_replacement = {'forward': partial(infer_forward)}
# self.append_or_create_method_replacement(description=method_replacement,
# policy=policy,
# target_key=LlamaRMSNorm)

return policy
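One way to pin down the precision issue noted in the TODO above is to compare the fused kernel against the stock `LlamaRMSNorm.forward`, which computes the variance in fp32 before casting back. A debugging sketch, not part of this PR (assumes vllm and a GPU):

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from vllm import layernorm_ops

hidden_size = 4096
norm = LlamaRMSNorm(hidden_size).half().cuda()
x = torch.randn(4, 128, hidden_size, dtype=torch.float16, device="cuda")

ref = norm(x)  # HF reference path

fused = torch.empty_like(x)
layernorm_ops.rms_norm(fused, x, norm.weight.data, norm.variance_epsilon)

print("max abs diff:", (ref - fused).abs().max().item())
```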
25 changes: 1 addition & 24 deletions colossalai/shardformer/modeling/llama.py
@@ -12,7 +12,6 @@
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
)
from transformers.utils import logging
@@ -21,10 +20,8 @@
from colossalai.pipeline.stage_manager import PipelineStageManager

try:
from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
from vllm import pos_encoding_ops
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
rms_norm = layernorm_ops.rms_norm
HAS_VLLM_KERNERL = True
except:
print("fall back to original rotary_embedding_neox of huggingface")
@@ -477,23 +474,3 @@ def forward(

return forward


def get_llama_vllm_rmsnorm_forward():

if HAS_VLLM_KERNERL:

def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)

return out

return _vllm_rmsnorm_forward
else:
return None