diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index adb2ad8a0170..7c77785b24e8 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -3,7 +3,13 @@
 import numpy as np
 import torch
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaModel,
+    apply_rotary_pos_emb,
+    LlamaRMSNorm
+)
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
@@ -11,7 +17,8 @@
 from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
 
 try:
-    from vllm import pos_encoding_ops
+    from vllm import pos_encoding_ops, layernorm_ops
+    rms_norm = layernorm_ops.rms_norm
     rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
     HAS_VLLM_KERNERL = True
 except:
@@ -45,14 +52,6 @@ def llama_model_forward(
 
         batch_size = input_ids.shape[0]  # input_ids.shape[0]
 
-        # infer_state = BatchInferState(batch_size, input_ids.shape[1])
-        # infer_state.batch_size = batch_size
-        # # NOTE: dummy implementation here for testing, just assume all inputs same length
-        # infer_state.block_loc = self.block_loc
-        # infer_state.start_loc = self.start_loc
-        # infer_state.seq_len = self.seq_len
-        # infer_state.max_len_in_batch = self.max_len_in_batch
-
         infer_state = self.infer_state
         b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
         position_ids = torch.from_numpy(
@@ -276,10 +275,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
 
             _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
                                   infer_state.decode_mem_index, infer_state.cache_manager)
-            # this is worse than destcopy
-            # torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states)
-            # torch.Tensor.copy_(infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],value_states)
-
             # FIXME might want to revise
             #   need some way to record the length of past key values cache
             #   since we won't return past_key_value_cache right now
@@ -291,14 +286,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
         if infer_state.is_context_stage:
             # first token generation
 
-            # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states,
-            #                                                                                         key_states,
-            #                                                                                         value_states,
-            #                                                                                         0,
-            #                                                                                         1/math.sqrt(self.head_dim),
-            #                                                                                         causal,
-            #                                                                                         False)
-
             attn_output = torch.empty_like(query_states)
 
             # calcu_shape for context_attention_fwd
@@ -325,3 +312,22 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
 
         # return past_key_value as None
         return attn_output, None, None
+
+def get_llama_vllm_rmsnorm_forward():
+
+    if HAS_VLLM_KERNERL:
+        def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+            x = hidden_states
+            out = torch.empty_like(x)
+            rms_norm(
+                out,
+                x,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+
+            return out
+
+        return _vllm_rmsnorm_forward
+    else:
+        return None
\ No newline at end of file
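Note: get_llama_vllm_rmsnorm_forward added above is a thin wrapper around vLLM's fused rms_norm kernel, called with the same (out, input, weight, epsilon) arguments as in the patch. A minimal sketch for comparing the fused kernel against the eager HuggingFace LlamaRMSNorm, which may help narrow down the precision issue flagged in the policy TODO below (assumes a CUDA device and a vLLM build that still ships layernorm_ops; the shapes and the final check are illustrative, not part of this patch):

    import torch
    from transformers.models.llama.modeling_llama import LlamaRMSNorm
    from vllm import layernorm_ops  # same fused kernel the wrapper uses

    hidden_size = 4096
    norm = LlamaRMSNorm(hidden_size, eps=1e-6).half().cuda()
    x = torch.randn(2, 128, hidden_size, dtype=torch.float16, device="cuda")

    # eager HuggingFace reference (computes the variance in fp32)
    ref = norm(x)

    # fused vLLM kernel, invoked exactly like _vllm_rmsnorm_forward above
    out = torch.empty_like(x)
    layernorm_ops.rms_norm(out, x, norm.weight.data, norm.variance_epsilon)

    # illustrative check only; fp16 kernels usually differ in the last few bits
    print((out - ref).abs().max())
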
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
index 997f5fe48a54..c569a0e3163a 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,8 +1,10 @@
 from functools import partial
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
 from ..modeling.llama import LlamaInferenceForwards
+from ..modeling.llama import get_llama_vllm_rmsnorm_forward
 
 
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
@@ -11,7 +13,6 @@ def __init__(self) -> None:
         super().__init__()
 
     def module_policy(self):
-        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
         policy = super().module_policy()
 
         self.shard_config._infer()
@@ -36,5 +37,13 @@ def module_policy(self):
         self.append_or_create_method_replacement(description=method_replacement,
                                                  policy=policy,
                                                  target_key=LlamaAttention)
+
+        # TODO: adding rms_norm caused precision issue, fix @tiandiao123
+        # infer_forward = get_llama_vllm_rmsnorm_forward()
+        # if infer_forward is not None:
+        #     method_replacement = {'forward': partial(infer_forward)}
+        #     self.append_or_create_method_replacement(description=method_replacement,
+        #                                              policy=policy,
+        #                                              target_key=LlamaRMSNorm)
 
         return policy
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 2224539d273e..08220eb73427 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -12,7 +12,6 @@
     LlamaForCausalLM,
     LlamaForSequenceClassification,
     LlamaModel,
-    LlamaRMSNorm,
     apply_rotary_pos_emb,
 )
 from transformers.utils import logging
@@ -21,10 +20,8 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
 try:
-    from vllm import layernorm_ops, pos_encoding_ops
-    rms_norm = layernorm_ops.rms_norm
+    from vllm import pos_encoding_ops
     rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
-    rms_norm = layernorm_ops.rms_norm
     HAS_VLLM_KERNERL = True
 except:
     print("fall back to original rotary_embedding_neox of huggingface")
@@ -477,23 +474,3 @@ def forward(
 
     return forward
 
-
-def get_llama_vllm_rmsnorm_forward():
-
-    if HAS_VLLM_KERNERL:
-
-        def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            x = hidden_states
-            out = torch.empty_like(x)
-            rms_norm(
-                out,
-                x,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-
-            return out
-
-        return _vllm_rmsnorm_forward
-    else:
-        return None
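Note: with this patch, get_llama_vllm_rmsnorm_forward is removed from colossalai/shardformer/modeling/llama.py and lives only in the inference modeling module, so any external caller now imports it from the new path. A small usage sketch under that assumption (the policy hook above stays commented out until the precision issue is fixed; the factory returns None when the vLLM kernels are not installed):

    import torch
    from transformers.models.llama.modeling_llama import LlamaRMSNorm

    from colossalai.inference.tensor_parallel.modeling.llama import get_llama_vllm_rmsnorm_forward

    infer_forward = get_llama_vllm_rmsnorm_forward()  # None if vLLM kernels are unavailable
    if infer_forward is not None:
        norm = LlamaRMSNorm(4096, eps=1e-6).half().cuda()
        hidden_states = torch.randn(1, 16, 4096, dtype=torch.float16, device="cuda")
        out = infer_forward(norm, hidden_states)  # drop-in replacement for norm(hidden_states)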