From a0df081e38feb5cb585ad200380293800afef3fe Mon Sep 17 00:00:00 2001
From: Xu Kai
Date: Thu, 31 Aug 2023 17:46:22 +0800
Subject: [PATCH] reset shardformer llama

---
 .../tensor_parallel/modeling/llama.py         | 20 ++++++++++++--------
 colossalai/shardformer/modeling/llama.py      | 19 +------------------
 2 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 7c77785b24e8..1d9e366f99f3 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -4,11 +4,11 @@
 import torch
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaModel,
-    apply_rotary_pos_emb,
-    LlamaRMSNorm
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaModel,
+    LlamaRMSNorm,
+    apply_rotary_pos_emb,
 )
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
@@ -17,7 +17,7 @@
 from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
 
 try:
-    from vllm import pos_encoding_ops, layernorm_ops
+    from vllm import layernorm_ops, pos_encoding_ops
     rms_norm = layernorm_ops.rms_norm
     rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
     HAS_VLLM_KERNERL = True
@@ -255,7 +255,9 @@ def llama_flash_attn_kvcache_forward(
     if HAS_VLLM_KERNERL:
         cos_sin_cache = torch.cat((cos, sin), dim=-1)
         rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
+        key_states = key_states_transposed.transpose(1, 2)
     else:
+        # TODO: there are some issues for original rotary_embedding_neox of huggingface
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)
 
     def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
@@ -313,9 +315,11 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
     # return past_key_value as None
     return attn_output, None, None
 
+
 def get_llama_vllm_rmsnorm_forward():
-
+
     if HAS_VLLM_KERNERL:
+
         def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             x = hidden_states
             out = torch.empty_like(x)
@@ -330,4 +334,4 @@ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
 
         return _vllm_rmsnorm_forward
     else:
-        return None
\ No newline at end of file
+        return None
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 08220eb73427..f26248d44612 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -19,18 +19,6 @@
 from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-try:
-    from vllm import pos_encoding_ops
-    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
-    HAS_VLLM_KERNERL = True
-except:
-    print("fall back to original rotary_embedding_neox of huggingface")
-    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
-    print(
-        "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
-    )
-    HAS_VLLM_KERNERL = False
-
 
 class LlamaPipelineForwards:
     '''
@@ -434,11 +422,7 @@ def forward(
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
-        if HAS_VLLM_KERNERL:
-            cos_sin_cache = torch.cat((cos, sin), dim=-1)
-            rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache)
-        else:
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
@@ -473,4 +457,3 @@ def forward(
         return attn_output, None, past_key_value
 
     return forward
-
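
A note on the inference-side hunk, for reviewers: vLLM's rotary_embedding_neox kernel rotates the query/key tensors in place, so the added key_states = key_states_transposed.transpose(1, 2) line does no extra work. transpose() returns a view over the same storage, which re-exposes the rotated keys in the layout the downstream KV-cache copy presumably expects. Below is a minimal sketch of that view-sharing behaviour, assuming in-place kernel semantics; rotate_half_inplace_ is a hypothetical stand-in for the real kernel and the shapes are illustrative, not taken from the patch.

    import torch


    def rotate_half_inplace_(x: torch.Tensor) -> None:
        # Hypothetical stand-in for an in-place rotary kernel:
        # swap and negate the two halves of the last dimension.
        half = x.shape[-1] // 2
        first = x[..., :half].clone()
        x[..., :half] = -x[..., half:]
        x[..., half:] = first


    bsz, q_len, num_heads, head_dim = 2, 4, 8, 16
    key_states = torch.randn(bsz, q_len, num_heads, head_dim)
    key_states_transposed = key_states.transpose(1, 2)  # a view, no copy

    rotate_half_inplace_(key_states_transposed)

    # Transposing back exposes the rotated values in the original
    # layout; both tensors still share a single storage.
    key_states = key_states_transposed.transpose(1, 2)
    assert key_states.shape == (bsz, q_len, num_heads, head_dim)
    assert key_states.data_ptr() == key_states_transposed.data_ptr()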
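
Similarly, get_llama_vllm_rmsnorm_forward() returns a drop-in replacement for LlamaRMSNorm.forward when the vLLM layernorm kernel imports successfully, and None otherwise, so callers can patch conditionally. A hypothetical wiring sketch follows; the class-level patching shown here is assumed usage, not something the patch itself performs.

    from transformers.models.llama.modeling_llama import LlamaRMSNorm

    from colossalai.inference.tensor_parallel.modeling.llama import get_llama_vllm_rmsnorm_forward

    # Returns None when vLLM (and hence layernorm_ops) is unavailable.
    vllm_rmsnorm_forward = get_llama_vllm_rmsnorm_forward()
    if vllm_rmsnorm_forward is not None:
        # Patch at class level; `self` binds per instance at call time.
        LlamaRMSNorm.forward = vllm_rmsnorm_forward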