From 9f66250e3f8c1a53fdbf0bffb087d4939da55587 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 31 Aug 2023 13:51:28 +0800 Subject: [PATCH 1/3] change import vllm --- .../tensor_parallel/modeling/llama.py | 141 +++++++++--------- colossalai/shardformer/modeling/llama.py | 47 +++--- 2 files changed, 97 insertions(+), 91 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ce099c61bda7..9647833ade84 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,24 +1,34 @@ from typing import List, Optional, Tuple -import torch import numpy as np -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, -) -from transformers.models.llama.modeling_llama import LlamaModel, LlamaDecoderLayer, LlamaAttention +import torch +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel + from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest from colossalai.kernel.triton.context_attention import llama_context_attn_fwd +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd -from typing import List, Optional, Tuple -from transformers.modeling_outputs import BaseModelOutputWithPast + +try: + from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print( + "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + ) + HAS_VLLM_KERNERL = False + class LlamaInferenceForwards: """ This class holds forwards for llama inference. We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. 
""" - + @staticmethod def llama_model_forward( self: LlamaModel, @@ -32,8 +42,8 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - - batch_size = input_ids.shape[0] # input_ids.shape[0] + + batch_size = input_ids.shape[0] # input_ids.shape[0] # infer_state = BatchInferState(batch_size, input_ids.shape[1]) # infer_state.batch_size = batch_size @@ -45,13 +55,13 @@ def llama_model_forward( infer_state = self.infer_state b_seq_len_numpy = infer_state.seq_len.cpu().numpy() - position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() - + position_ids = torch.from_numpy( + np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + # this equals infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) - + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds @@ -72,15 +82,16 @@ def llama_model_forward( past_key_values_length = infer_state.cache_manager.past_key_values_length # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - + # FIXME: differentiate with prefill stage # block_loc require different value-assigning method for two different stage if use_cache and seq_length != 1: # NOTE assuem prefill stage # allocate memory block - infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index) + infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, + infer_state.context_mem_index) else: infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) @@ -92,7 +103,9 @@ def llama_model_forward( infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print(f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem @@ -102,9 +115,10 @@ def llama_model_forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -114,13 +128,12 @@ def llama_model_forward( # embed positions if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - - 
attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, + past_key_values_length) hidden_states = inputs_embeds @@ -145,7 +158,7 @@ def llama_model_forward( ) infer_state.decode_layer_id += 1 hidden_states = layer_outputs[0] - + if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) @@ -159,14 +172,14 @@ def llama_model_forward( if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - + return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) - + @staticmethod def llama_decoder_layer_forward( self: LlamaDecoderLayer, @@ -212,7 +225,6 @@ def llama_decoder_layer_forward( return outputs - @staticmethod def llama_flash_attn_kvcache_forward( self: LlamaAttention, @@ -228,7 +240,7 @@ def llama_flash_attn_kvcache_forward( assert use_cache is True, "use_cache should be set to True using this llama attention" bsz, q_len, _ = hidden_states.size() - + # TODO might think about better way to handle transposed k and v # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] @@ -237,16 +249,16 @@ def llama_flash_attn_kvcache_forward( key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) key_states_transposed = key_states.transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - + # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len) - cos ,sin = infer_state.position_cos, infer_state.position_sin - - cos_sin_cache = torch.cat((cos, sin), dim=-1) - - from vllm.pos_encoding_ops import rotary_embedding_neox - - rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) - + cos, sin = infer_state.position_cos, infer_state.position_sin + + if HAS_VLLM_KERNERL: + cos_sin_cache = torch.cat((cos, sin), dim=-1) + rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids) + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): num_heads = key_buffer.shape[2] head_dim = key_buffer.shape[3] @@ -258,9 +270,11 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # copy key and value calculated in current step to memory manager if infer_state.is_context_stage: - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, infer_state.cache_manager) + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, + infer_state.cache_manager) else: - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, infer_state.cache_manager) + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index, + infer_state.cache_manager) # this is worse than destcopy # 
torch.Tensor.copy_(infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][infer_state.decode_mem_start:infer_state.decode_mem_end, :, :],key_states) @@ -269,19 +283,19 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # FIXME might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len query_states = query_states.transpose(1, 2) if infer_state.is_context_stage: # first token generation - # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states, + # attn_output, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(query_states, # key_states, - # value_states, + # value_states, # 0, - # 1/math.sqrt(self.head_dim), + # 1/math.sqrt(self.head_dim), # causal, # False) @@ -290,33 +304,24 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # calcu_shape for context_attention_fwd calcu_shape1 = (-1, self.num_heads, self.head_dim) - llama_context_attn_fwd(query_states.view(calcu_shape1), - key_states.view(calcu_shape1), - value_states.view(calcu_shape1), - attn_output.view(calcu_shape1), - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length) + llama_context_attn_fwd(query_states.view(calcu_shape1), key_states.view(calcu_shape1), + value_states.view(calcu_shape1), attn_output.view(calcu_shape1), + infer_state.start_loc, infer_state.seq_len, + infer_state.cache_manager.past_key_values_length) else: # second token and follows # kv = torch.stack((key_states, value_states), dim=2) # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_states) - - token_attention_fwd(query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, + + token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, + infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length) attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - # return past_key_value as None + # return past_key_value as None return attn_output, None, None - - \ No newline at end of file diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 294ab87709c6..bf77e08e3c39 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,11 +7,30 @@ CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaModel, + LlamaRMSNorm, +) from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager +try: + from vllm import layernorm_ops, 
pos_encoding_ops + rms_norm = layernorm_ops.rms_norm + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print( + "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + ) + HAS_VLLM_KERNERL = False + class LlamaPipelineForwards: ''' @@ -394,18 +413,8 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention - - try: - from vllm import pos_encoding_ops - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True - except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") - HAS_VLLM_KERNERL = False - def forward( self: LlamaAttention, @@ -428,7 +437,7 @@ def forward( kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - + if HAS_VLLM_KERNERL: cos_sin_cache = torch.cat((cos, sin), dim=-1) rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache) @@ -471,17 +480,9 @@ def forward( def get_llama_vllm_rmsnorm_forward(): - try: - from vllm import layernorm_ops - rms_norm = layernorm_ops.rms_norm - HAS_VLLM_KERNERL = True - except: - print("please install vllm kernels to install rmsnorm") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") - HAS_VLLM_KERNERL = False - + if HAS_VLLM_KERNERL: + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): x = hidden_states out = torch.empty_like(x) From 5d00cf17dbedfe46540c1a0209f75ba660976d77 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 31 Aug 2023 14:08:05 +0800 Subject: [PATCH 2/3] import apply_rotary_pos_emb --- colossalai/inference/tensor_parallel/modeling/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 9647833ade84..adb2ad8a0170 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -3,7 +3,7 @@ import numpy as np import torch from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, apply_rotary_pos_emb from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.context_attention import llama_context_attn_fwd From 7b09536db3248e3187ff9967a51cd10bb1b6b0b9 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 31 Aug 2023 14:13:59 +0800 Subject: [PATCH 3/3] change import location --- colossalai/shardformer/modeling/llama.py | 7 +++---- 1 
file changed, 3 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index bf77e08e3c39..2224539d273e 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -8,13 +8,16 @@ SequenceClassifierOutputWithPast, ) from transformers.models.llama.modeling_llama import ( + LlamaAttention, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm, + apply_rotary_pos_emb, ) from transformers.utils import logging +from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention from colossalai.pipeline.stage_manager import PipelineStageManager try: @@ -412,10 +415,6 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(): - from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention - def forward( self: LlamaAttention, hidden_states: torch.Tensor,
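Note on the pattern introduced above: all three patches converge on the same optional-kernel idiom — try to import vllm's fused ops once at module import time, remember the result in a flag, and branch on that flag inside the forward pass instead of importing vllm in the hot path. Below is a minimal, self-contained sketch of the rotary-embedding branch, assuming only what the patch itself shows (pos_encoding_ops.rotary_embedding_neox from vllm, apply_rotary_pos_emb from HuggingFace transformers); the helper name apply_rotary and the corrected flag spelling HAS_VLLM_KERNEL are illustrative only and not part of the patch (which spells the flag HAS_VLLM_KERNERL).

import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

try:
    # Fused CUDA kernel shipped with vllm; rotates query/key in place.
    from vllm import pos_encoding_ops
    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
    HAS_VLLM_KERNEL = True
except ImportError:
    # vllm not installed: fall back to HuggingFace's pure-PyTorch rotary embedding.
    HAS_VLLM_KERNEL = False


def apply_rotary(query_states, key_states, cos, sin, position_ids, head_dim):
    # Hypothetical helper for illustration; the patch inlines this logic in
    # llama_flash_attn_kvcache_forward and get_llama_flash_attention_forward.
    if HAS_VLLM_KERNEL:
        # The vllm kernel consumes a single concatenated cos/sin cache and
        # mutates query_states / key_states in place.
        cos_sin_cache = torch.cat((cos, sin), dim=-1)
        rotary_embedding_neox(position_ids, query_states, key_states, head_dim, cos_sin_cache)
        return query_states, key_states
    # Fallback: HuggingFace's reference implementation, which returns new tensors.
    return apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

The design point worth noting is that the fused path rotates the tensors in place while the HuggingFace fallback returns fresh tensors, which is why the patch keeps the two call shapes side by side behind the flag rather than hiding them in a single wrapper.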
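The same idiom backs the RMSNorm replacement in colossalai/shardformer/modeling/llama.py: get_llama_vllm_rmsnorm_forward only returns a substitute forward when layernorm_ops imported successfully. A minimal sketch follows; the rms_norm(out, x, weight, eps) call signature is an assumption based on the vllm kernels this branch targets (the patch only shows the output-buffer allocation), and the fallback simply restates LlamaRMSNorm's reference math.

import torch

try:
    from vllm import layernorm_ops
    HAS_VLLM_KERNEL = True
except ImportError:
    HAS_VLLM_KERNEL = False


def rmsnorm_forward(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical free-function version of LlamaRMSNorm.forward, for illustration only.
    if HAS_VLLM_KERNEL:
        # Fused kernel writes the normalized result into a preallocated buffer,
        # mirroring the out = torch.empty_like(x) line visible in the patch.
        # The argument order (out, input, weight, eps) is an assumption.
        out = torch.empty_like(hidden_states)
        layernorm_ops.rms_norm(out, hidden_states, weight, eps)
        return out
    # Pure-PyTorch fallback, equivalent to HuggingFace's LlamaRMSNorm.
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

Allocating the output buffer on the Python side and handing it to the kernel avoids an extra allocation inside the CUDA op, which is presumably why the patch keeps the torch.empty_like(x) call in _vllm_rmsnorm_forward.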