From 3e6175ca47447f0c71685dc551ba6a6532006529 Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Wed, 6 Sep 2023 12:17:53 +0800
Subject: [PATCH 01/22] add chatglm2

---
 .../tensor_parallel/modeling/chatglm2.py      | 151 ++++++++++++++++++
 .../tensor_parallel/modeling/llama.py         |   2 +-
 2 files changed, 152 insertions(+), 1 deletion(-)
 create mode 100644 colossalai/inference/tensor_parallel/modeling/chatglm2.py

diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
new file mode 100644
index 000000000000..0cd4efd4b87d
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -0,0 +1,151 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
+from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+
+try:
+    from vllm import layernorm_ops, pos_encoding_ops
+    rms_norm = layernorm_ops.rms_norm
+    rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
+    HAS_VLLM_KERNERL = True
+except:
+    print("falling back to the original rotary_embedding_neox of huggingface")
+    print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+    print(
+        "if you fail to install vllm, please use this branch to install it instead: https://github.com/tiandiao123/vllm/tree/setup_branch"
+    )
+    HAS_VLLM_KERNERL = False
+
+
+class ChatGLM2InferenceForwards:
+    """
+    This class holds forwards for ChatGLM2 inference.
+    We intend to replace the forward methods for ChatGLMModel, ChatGLMDecoderLayer, and ChatGLMAttention.
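+    The forwards read per-batch decoding state (kv-cache block locations, rotary cos/sin tables)
+    from the BatchInferState bound to the model as self.infer_state.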
+    """
+
+    @staticmethod
+    def chatglm_model_forward(
+        self: ChatGLMModel,
+        input_ids,
+        position_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        full_attention_mask: Optional[torch.BoolTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+        batch_size, seq_length = input_ids.shape
+
+        infer_state = self.infer_state
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        # the prefill (context) stage comes first
+        if use_cache and seq_length != 1:
+            infer_state.is_context_stage = True
+            infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+            infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length,
+                                       infer_state.context_mem_index)
+        else:
+            infer_state.is_context_stage = False
+            alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+            if alloc_mem is not None:
+                infer_state.decode_is_contiguous = True
+                infer_state.decode_mem_index = alloc_mem[0]
+                infer_state.decode_mem_start = alloc_mem[1]
+                infer_state.decode_mem_end = alloc_mem[2]
+                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+            else:
+                print(" *** Encountered a non-contiguous cache allocation")
+                print(
+                    f"    infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+                )
+                infer_state.decode_is_contiguous = False
+                alloc_mem = infer_state.cache_manager.alloc(batch_size)
+                infer_state.decode_mem_index = alloc_mem
+                # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+                # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
+                infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+        if infer_state.is_context_stage:
+
+            infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+                position_ids.view(-1).shape[0], -1)
+            infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+                position_ids.view(-1).shape[0], -1)
+        else:
+            seq_len = infer_state.seq_len
+            infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+            infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embedding(input_ids)
+
+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(
+                    batch_size=batch_size,
+                    device=input_ids.device,
+                    dtype=inputs_embeds.dtype,
+                )
+            if attention_mask is not None:
+                attention_mask = torch.cat(
+                    [
+                        attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                        attention_mask,
+                    ],
+                    dim=-1,
+                )
+
+        if full_attention_mask is None:
+            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                full_attention_mask = self.get_masks(input_ids, past_key_values, 
padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 07b73a6f4ca6..b946e5a4b9f2 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -101,7 +101,7 @@ def llama_model_forward( # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage if use_cache and seq_length != 1: - # NOTE assuem prefill stage + # NOTE assume prefill stage # allocate memory block infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) From e992ff74fd677fc0b1bd600f62ffb5f4ccd0aedb Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 6 Sep 2023 13:41:08 +0800 Subject: [PATCH 02/22] add --- .../tensor_parallel/modeling/__init__.py | 3 +- .../tensor_parallel/modeling/chatglm2.py | 65 ++++++++++++++++++- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py index 7a98b033f37e..85b2bdf8c09c 100644 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -1,4 +1,5 @@ from .bloom import BloomInferenceForwards +from .chatglm2 import ChatGLM2InferenceForwards from .llama import LlamaInferenceForwards -__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards'] +__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards', 'ChatGLM2InferenceForwards'] diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 0cd4efd4b87d..12b72add18a1 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -3,14 +3,17 @@ import numpy as np import torch from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.context_attention import llama_context_attn_fwd from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd -from 
colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, + ChatGLMModel, + GLMTransformer, +) try: from vllm import layernorm_ops, pos_encoding_ops @@ -133,7 +136,7 @@ def chatglm_model_forward( kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, - ) + infer_state=infer_state) if not return_dict: return tuple(v for v in [ @@ -149,3 +152,59 @@ def chatglm_model_forward( hidden_states=all_hidden_states, attentions=all_self_attentions, ) + + @staticmethod + def chatglm_encoder_forward( + self: GLMTransformer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache, + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache, + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. 
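+        # (chatglm2-6b enables post_layer_norm by default, so this final RMSNorm normally runs)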
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states, presents, all_hidden_states, all_self_attentions

From 79a0fa6ecdcd230a610f35406194e647a61eb1f1 Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Thu, 7 Sep 2023 17:43:58 +0800
Subject: [PATCH 03/22] gather needed kernels

---
 .../inference/tensor_parallel/engine.py       |  11 +-
 .../tensor_parallel/modeling/_utils.py        |  10 +
 .../tensor_parallel/modeling/chatglm2.py      | 355 ++++++++++++--
 .../tensor_parallel/policies/__init__.py      |   3 +-
 .../tensor_parallel/policies/chatglm2.py      |  59 +++
 colossalai/kernel/triton/context_attention.py | 434 ++++++++++++++++--
 .../kernel/triton/token_attention_kernel.py   | 385 ++++++++++++++++
 .../shardformer/policies/auto_policy.py       |   5 +-
 tests/test_infer/test_chatglm2_infer.py       |  75 +++
 9 files changed, 1260 insertions(+), 77 deletions(-)
 create mode 100644 colossalai/inference/tensor_parallel/modeling/_utils.py
 create mode 100644 colossalai/inference/tensor_parallel/policies/chatglm2.py
 create mode 100644 tests/test_infer/test_chatglm2_infer.py

diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index a5a55702ade0..223320f1cd97 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -15,7 +15,7 @@

 DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

-_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
+_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM', 'ChatGLMModel']


 class TPInferEngine:
@@ -59,9 +59,12 @@ def __init__(self,

         self.dtype = dtype

-        self.head_dim = model.config.hidden_size // model.config.num_attention_heads
-        self.head_num = model.config.num_attention_heads
-        self.layer_num = model.config.num_hidden_layers
+        self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
+        self.head_num = self.model.config.num_attention_heads
+
+        # ChatGLM2 configs expose num_layers instead of num_hidden_layers
+        num_hidden_layers = self.model.config.num_hidden_layers if hasattr(self.model.config, "num_hidden_layers") \
+            else self.model.config.num_layers
+        self.layer_num = num_hidden_layers

         self.tp_size = -1    # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None

diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py
new file mode 100644
index 000000000000..282871cabac0
--- /dev/null
+++ b/colossalai/inference/tensor_parallel/modeling/_utils.py
@@ -0,0 +1,10 @@
+"""
+Utils for model inference
+"""
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+
+
+def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+    return

diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
index 12b72add18a1..7870fbedcdb2 100644
--- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py
+++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -1,18 +1,24 @@
+import os
 from typing import List, Optional, Tuple

+from . import _utils
 import numpy as np
 import torch
+from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
+from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
 from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
 from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMForConditionalGeneration,
     ChatGLMModel,
+    GLMBlock,
     GLMTransformer,
+    SelfAttention,
+    apply_rotary_pos_emb,
+    split_tensor_along_last_dim,
 )

 try:
     from vllm import layernorm_ops, pos_encoding_ops
@@ -29,12 +35,109 @@
     HAS_VLLM_KERNERL = False


+# This function is the same as the llama model's init_to_get_rotary; both should be moved into _utils.py
+def _init_to_get_rotary(self, base=10000):
+    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+    if not hasattr(self.config, "rope_scaling"):
+        rope_scaling_factor = 1.0
+    else:
+        rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+    if hasattr(self.config, "max_sequence_length"):
+        max_seq_len = self.config.max_sequence_length
+    elif hasattr(self.config, "max_position_embeddings"):
+        max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+    else:
+        max_seq_len = 2048 * rope_scaling_factor
+    base = float(base)

+    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+    try:
+        ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
+        assert ntk_alpha >= 1
+        if ntk_alpha > 1:
+            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+        max_seq_len *= ntk_alpha
+        base = base * (ntk_alpha**(self.config.head_dim_ / (self.config.head_dim_ - 2)))    # base change formula
+    except:
+        pass
+
+    n_elem = self.config.head_dim_ // 2
+    inv_freq = 1.0 / (base**(torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+    freqs = torch.outer(t, inv_freq)
+
+    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+    return
+
+
 class ChatGLM2InferenceForwards:
     """
     This class holds forwards for ChatGLM2 inference.
     We intend to replace the forward methods for ChatGLMModel, ChatGLMDecoderLayer, and ChatGLMAttention.
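     The forwards read per-batch decoding state (kv-cache block locations, rotary cos/sin tables)
     from the BatchInferState bound to the model as self.infer_state.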
""" + @staticmethod + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + @staticmethod def chatglm_model_forward( self: ChatGLMModel, @@ -52,7 +155,6 @@ def chatglm_model_forward( if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) - batch_size, seq_length = input_ids.shape infer_state = self.infer_state @@ -86,6 +188,7 @@ def chatglm_model_forward( # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + #related to rotary embedding if infer_state.is_context_stage: infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( @@ -96,6 +199,7 @@ def chatglm_model_forward( seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) @@ -138,6 +242,11 @@ def chatglm_model_forward( output_hidden_states=output_hidden_states, infer_state=infer_state) + # update indices + # 
infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + if not return_dict: return tuple(v for v in [ hidden_states, @@ -162,40 +271,31 @@ def chatglm_encoder_forward( kv_caches=None, use_cache: Optional[bool] = True, output_hidden_states: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, ): + if not kv_caches: kv_caches = [None for _ in range(self.num_layers)] presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - all_self_attentions = None all_hidden_states = () if output_hidden_states else None + + infer_state.decode_layer_id = 0 + for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache, - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache, - ) + layer = self.layers[index] + + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache, + infer_state=infer_state, + ) + + infer_state.decode_layer_id += 1 + hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) @@ -208,3 +308,196 @@ def chatglm_encoder_forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states, presents, all_hidden_states, all_self_attentions + + @staticmethod + def chatglm_glmblock_forward( + self: GLMBlock, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache, + infer_state=infer_state, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. 
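+        # (chatglm2-6b ships apply_residual_connection_post_layernorm=False, so the residual
+        #  below is layernorm_input, i.e. the attention output added onto the first residual)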
+ if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + @staticmethod + def chatglm_flash_attn_kvcache_forward( + self: SelfAttention, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + infer_state: Optional[BatchInferState] = None, + ): + assert use_cache is True, "use_cache should be set to True using this chatglm attention" + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + key_layer = key_layer.view(key_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.view(value_layer.size()[:-1] + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + )) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + cos, sin = infer_state.position_cos, infer_state.position_sin + + # apply relative positional encoding (rotary embedding) + # if rotary_pos_emb is not None: + # query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + # key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + rotary_embedding_fwd(query_layer.view(-1, self.num_heads, self.head_dim), cos, sin) + rotary_embedding_fwd(key_layer.view(-1, self.num_heads, self.head_dim), cos, sin) + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + )) + + query_layer = query_layer.reshape(-1, self.num_heads, self.head_dim) + key_layer = key_layer.reshape(-1, self.num_heads, self.head_dim) + value_layer = value_layer.reshape(-1, self.num_heads, self.head_dim) + + if 
infer_state.is_context_stage: + # first token generation: + # copy key and value calculated in current step to memory manager + _utils._copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, + infer_state.context_mem_index, infer_state.cache_manager) + + attn_output = torch.empty_like(query_layer) + + llama2_context_attn_fwd(query_layer, key_layer, value_layer, attn_output, infer_state.start_loc, + infer_state.seq_len, infer_state.cache_manager.past_key_values_length) + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + cache_k.copy_(key_layer) + cache_v.copy_(value_layer) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + _utils._copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, + infer_state.decode_mem_index, infer_state.cache_manager) + + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_layer) + + Llama2TokenAttentionForwards.token_attn( + query_layer, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, infer_state.block_loc, + infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length, + infer_state.other_kv_index) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. 
[sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py index 48f8db62c32a..22e0abd2b2ed 100644 --- a/colossalai/inference/tensor_parallel/policies/__init__.py +++ b/colossalai/inference/tensor_parallel/policies/__init__.py @@ -1,4 +1,5 @@ from .bloom import BloomModelInferPolicy +from .chatglm2 import ChatGLM2InferPolicy from .llama import LlamaModelInferPolicy -__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy'] +__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy', 'ChatGLM2InferPolicy'] diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py new file mode 100644 index 000000000000..0e4ec73d49b9 --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -0,0 +1,59 @@ +from functools import partial + +import torch + +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMModel, + GLMBlock, + GLMTransformer, + SelfAttention, +) +# import colossalai +from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy + +from ..modeling.chatglm2 import ChatGLM2InferenceForwards, _init_to_get_rotary + +try: + from colossalai.kernel.triton.rms_norm import rmsnorm_forward + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +class ChatGLM2InferPolicy(ChatGLMModelPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + self.shard_config._infer() + + model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward + method_replacement = {'forward': partial(model_infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel) + + encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward + method_replacement = {'forward': partial(encoder_infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=GLMTransformer) + + encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward + method_replacement = {'forward': partial(encoder_layer_infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock) + + attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward + method_replacement = {'forward': partial(attn_infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=SelfAttention) + + # for rmsnorm and others, we need to check the shape + + return policy + + def postprocess(self): + _init_to_get_rotary(self.model) + return self.model diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 38db2048c6a4..9c244bfc2b8b 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -1,5 +1,7 @@ -import torch import math + +import torch + try: import triton import triton.language as tl @@ -8,26 +10,40 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") - if HAS_TRITON: ''' - this function is modified from + this function is modified from 
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 ''' + @triton.jit def _context_flash_attention_kernel( - Q, K, V, sm_scale, - B_Start_Loc, B_Seqlen, - TMP, + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, alibi_ptr, Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_tmp_b, stride_tmp_h, stride_tmp_s, - # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -40,13 +56,14 @@ def _context_flash_attention_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info + + # get batch info cur_batch_seq_len = tl.load(B_Seqlen + batch_id) cur_batch_start_index = tl.load(B_Start_Loc + batch_id) block_start_loc = BLOCK_M * start_m - - load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + + load_p_ptrs = Q + (cur_batch_start_index + + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd @@ -56,7 +73,7 @@ def _context_flash_attention_kernel( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - + if alibi_ptr is not None: alibi_m = tl.load(alibi_ptr + cur_head) @@ -65,7 +82,8 @@ def _context_flash_attention_kernel( for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -96,20 +114,21 @@ def _context_flash_attention_kernel( acc = acc * acc_scale[:, None] # update acc v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new - - off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + + off_o = (cur_batch_start_index + + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return - - + @torch.no_grad() def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): BLOCK = 128 @@ -129,17 +148,31 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) _context_flash_attention_kernel[grid]( - q, k, v, sm_scale, - 
b_start_loc, b_seq_len,
+            q,
+            k,
+            v,
+            sm_scale,
+            b_start_loc,
+            b_seq_len,
             tmp,
             alibi,
             o,
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            o.stride(0), o.stride(1), o.stride(2),
-            tmp.stride(0), tmp.stride(1), tmp.stride(2),
-            # manually setting this blcok num, we can use tuning config to futher speed-up
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            tmp.stride(0),
+            tmp.stride(1),
+            tmp.stride(2),
+            # manually setting this block num; we can use a tuning config to further speed this up
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
@@ -147,7 +180,7 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al
             num_stages=1,
         )
         return
-
+
     @torch.no_grad()
     def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         BLOCK = 128
@@ -166,19 +199,340 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         num_warps = 4 if Lk <= 64 else 8
         # num_warps = 4
         _context_flash_attention_kernel[grid](
-            q, k, v, sm_scale, b_start_loc, b_seq_len,
+            q,
+            k,
+            v,
+            sm_scale,
+            b_start_loc,
+            b_seq_len,
             tmp,
             None,
             o,
-            q.stride(0), q.stride(1), q.stride(2),
-            k.stride(0), k.stride(1), k.stride(2),
-            v.stride(0), v.stride(1), v.stride(2),
-            o.stride(0), o.stride(1), o.stride(2),
-            tmp.stride(0), tmp.stride(1), tmp.stride(2),
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            tmp.stride(0),
+            tmp.stride(1),
+            tmp.stride(2),
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
             num_warps=num_warps,
             num_stages=1,
         )
-        return
\ No newline at end of file
+        return
+
+    @triton.jit
+    def _fwd_kernel_latest(
+        Q,
+        K,
+        V,
+        sm_scale,
+        B_Start_Loc,
+        B_Seqlen,    # B_Loc records the real positions of each batch's input; B_Seqlen records the real length of the current input
+        Out,
+        stride_qbs,
+        stride_qh,
+        stride_qd,
+        stride_kbs,
+        stride_kh,
+        stride_kd,
+        stride_vbs,
+        stride_vh,
+        stride_vd,
+        stride_obs,
+        stride_oh,
+        stride_od,
+        kv_group_num,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+        start_m = tl.program_id(2)
+
+        cur_kv_head = cur_head // kv_group_num
+
+        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+        block_start_loc = BLOCK_M * start_m
+
+        # initialize offsets
+        offs_n = tl.arange(0, BLOCK_N)
+        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        off_q = (cur_batch_in_all_start_index +
+                 offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
+        off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
+        off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
+
+        q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+
+        k_ptrs = K + off_k
+        v_ptrs = V + off_v
+
+        # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+        block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+
+        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- 
compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @triton.jit + def _fwd_kernel_old( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + kv_group_num, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + # t_ptrs = TMP + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= 
(start_n + offs_n[None, :]), qk, float("-inf"))
+
+            m_ij = tl.max(qk, 1)
+            p = tl.exp(qk - m_ij[:, None])
+            l_ij = tl.sum(p, 1)
+            # -- update m_i and l_i
+            m_i_new = tl.maximum(m_i, m_ij)
+            alpha = tl.exp(m_i - m_i_new)
+            beta = tl.exp(m_ij - m_i_new)
+            l_i_new = alpha * l_i + beta * l_ij
+            # -- update output accumulator --
+            # scale p
+            p_scale = beta / l_i_new
+            p = p * p_scale[:, None]
+            # scale acc
+            acc_scale = l_i / l_i_new * alpha
+            tl.store(t_ptrs, acc_scale)
+            acc_scale = tl.load(t_ptrs)    # BUG: have to store and immediately load
+            acc = acc * acc_scale[:, None]
+            # update acc
+            v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+                        mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
+                        other=0.0)
+
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v)
+            # update m_i and l_i
+            l_i = l_i_new
+            m_i = m_i_new
+        # initialize pointers to output
+        off_o = (cur_batch_in_all_start_index +
+                 offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+        out_ptrs = Out + off_o
+        tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+
+        return
+
+    @torch.no_grad()
+    def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
+        if triton.__version__ >= "2.1.0":
+            BLOCK = 128
+            # shape constraints
+            Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+            assert Lq == Lk and Lk == Lv
+            assert Lk in {16, 32, 64, 128}
+
+            sm_scale = 1.0 / (Lq**0.5)    # compute the scale factor
+            batch, head = b_seq_len.shape[0], q.shape[1]
+            kv_group_num = q.shape[1] // k.shape[1]
+
+            grid = (batch, head, triton.cdiv(max_input_len, BLOCK))    # batch, head,
+
+            num_warps = 4 if Lk <= 64 else 8
+            _fwd_kernel_latest[grid](
+                q,
+                k,
+                v,
+                sm_scale,
+                b_start_loc,
+                b_seq_len,
+                o,
+                q.stride(0),
+                q.stride(1),
+                q.stride(2),
+                k.stride(0),
+                k.stride(1),
+                k.stride(2),
+                v.stride(0),
+                v.stride(1),
+                v.stride(2),
+                o.stride(0),
+                o.stride(1),
+                o.stride(2),
+                kv_group_num=kv_group_num,
+                BLOCK_M=BLOCK,
+                BLOCK_DMODEL=Lk,
+                BLOCK_N=BLOCK,
+                num_warps=num_warps,
+                num_stages=1,
+            )
+            return
+        elif triton.__version__ == "2.0.0":
+            BLOCK = 128
+            # shape constraints
+            Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+            assert Lq == Lk and Lk == Lv
+            assert Lk in {16, 32, 64, 128}
+
+            sm_scale = 1.0 / (Lq**0.5)
+            batch, head = b_seq_len.shape[0], q.shape[1]
+            kv_group_num = q.shape[1] // k.shape[1]
+
+            grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
+
+            tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
+            num_warps = 4 if Lk <= 64 else 8
+            # num_warps = 4
+            _fwd_kernel_old[grid](
+                q,
+                k,
+                v,
+                sm_scale,
+                b_start_loc,
+                b_seq_len,
+                tmp,
+                o,
+                q.stride(0),
+                q.stride(1),
+                q.stride(2),
+                k.stride(0),
+                k.stride(1),
+                k.stride(2),
+                v.stride(0),
+                v.stride(1),
+                v.stride(2),
+                o.stride(0),
+                o.stride(1),
+                o.stride(2),
+                tmp.stride(0),
+                tmp.stride(1),
+                tmp.stride(2),
+                kv_group_num=kv_group_num,
+                BLOCK_M=BLOCK,
+                BLOCK_DMODEL=Lk,
+                BLOCK_N=BLOCK,
+                num_warps=num_warps,
+                num_stages=1,
+            )
+            return
diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py
index c6b25f4abcec..81528cebd70d 100644
--- a/colossalai/kernel/triton/token_attention_kernel.py
+++ b/colossalai/kernel/triton/token_attention_kernel.py
@@ -331,3 +331,388 @@ def token_attention_fwd(q,
         prob = None

     return
+
+
+class Llama2TokenAttentionForwards:
+
+    @staticmethod
+    @triton.jit
+    def _fwd_kernel(
+        Logics,
+        V,
+        Out,
+        B_Loc,
+        B_Start_Loc,
+        B_Seqlen,
+        max_input_len,
+        stride_logic_h,
+        stride_logic_bs,
+        stride_vbs,
+        stride_vh,
+        stride_vd,
+        stride_obs,
+        stride_oh,
+        stride_od,
+        stride_b_loc_b,
+        stride_b_loc_s,
+        other_kv_index,    # avoid reading NaN data for masked positions
+        kv_group_num,
+        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+
+        cur_kv_head = cur_head // kv_group_num
+
+        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+        cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+
+        offs_n = tl.arange(0, BLOCK_N)
+        offs_d = tl.arange(0, BLOCK_DMODEL)
+
+        off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
+        off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s
+
+        v_ptrs = V + off_v
+
+        e_max = float("-inf")
+        e_sum = 0.0
+        acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
+
+        for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            v_index = tl.load(B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s,
+                              mask=(start_n + offs_n) < cur_batch_seq_len,
+                              other=other_kv_index)
+
+            qk = tl.load(Logics + cur_head * stride_logic_h +
+                         (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs,
+                         mask=start_n + offs_n < cur_batch_seq_len,
+                         other=float("-inf"))
+
+            n_e_max = tl.maximum(tl.max(qk, 0), e_max)
+            old_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max)
+            e_sum = e_sum * old_scale + tl.sum(p, 0)
+            v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)
+            acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
+            e_max = n_e_max
+
+        acc = acc / e_sum
+        off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
+        out_ptrs = Out + off_o
+        tl.store(out_ptrs, acc)
+        return
+
+    @staticmethod
+    @torch.no_grad()
+    def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
+        BLOCK = 64
+        batch, head = b_seq_len.shape[0], logics.shape[0]
+        grid = (batch, head)
+        kv_group_num = logics.shape[0] // v.shape[1]
+
+        num_warps = 1
+        Llama2TokenAttentionForwards._fwd_kernel[grid](logics,
+                                                       v,
+                                                       o,
+                                                       b_loc,
+                                                       b_start_loc,
+                                                       b_seq_len,
+                                                       max_input_len,
+                                                       logics.stride(0),
+                                                       logics.stride(1),
+                                                       v.stride(0),
+                                                       v.stride(1),
+                                                       v.stride(2),
+                                                       o.stride(0),
+                                                       o.stride(1),
+                                                       o.stride(2),
+                                                       b_loc.stride(0),
+                                                       b_loc.stride(1),
+                                                       other_kv_index,
+                                                       kv_group_num,
+                                                       BLOCK_DMODEL=v.shape[-1],
+                                                       BLOCK_N=BLOCK,
+                                                       num_warps=num_warps,
+                                                       num_stages=3)
+        return
+
+    @staticmethod
+    @triton.jit
+    def _fwd_kernel_token_softmax(Logics, B_Start_Loc, B_Seqlen, Prob_Out, stride_logic_h, stride_logic_bs,
+                                  stride_prob_h, stride_prob_bs, BLOCK_SIZE: tl.constexpr):
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+        row = tl.load(Logics + cur_head * stride_logic_h +
+                      (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs,
+                      mask=col_offsets < cur_batch_seq_len,
+                      other=-float('inf')).to(tl.float32)
+
+        row_minus_max = row - tl.max(row, axis=0)
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=0)
+        softmax_output = numerator / denominator
+
+        tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs,
+                 softmax_output,
+                 mask=col_offsets < cur_batch_seq_len)
+        return
+
+    @staticmethod
+    @torch.no_grad()
+    def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len):
+        BLOCK_SIZE = triton.next_power_of_2(max_input_len)
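+        # one Triton program per (batch, head); BLOCK_SIZE spans the whole padded
+        # sequence, so each row's softmax is computed in a single block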
+        batch, head_num = B_Start_Loc.shape[0], Logics.shape[0]
+
+        num_warps = 4
+        if BLOCK_SIZE >= 2048:
+            num_warps = 8
+        if BLOCK_SIZE >= 4096:
+            num_warps = 16
+
+        Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)](
+            Logics,
+            B_Start_Loc,
+            B_Seqlen,
+            Prob_Out,
+            Logics.stride(0),
+            Logics.stride(1),
+            Prob_Out.stride(0),
+            Prob_Out.stride(1),
+            num_warps=num_warps,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
+        return
+
+    @staticmethod
+    @triton.jit
+    def _fwd_kernel_token_att1(Q, K, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len, Att_Out, stride_b_loc_b,
+                               stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd,
+                               att_stride_h, att_stride_bs, kv_group_num, BLOCK_DMODEL: tl.constexpr,
+                               BLOCK_N: tl.constexpr):
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+        start_n = tl.program_id(2)
+
+        cur_kv_head = cur_head // kv_group_num
+
+        offs_d = tl.arange(0, BLOCK_DMODEL)
+        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+        cur_batch_start_index = max_input_len - cur_batch_seq_len
+        cur_batch_end_index = max_input_len
+
+        off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd
+
+        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+        block_start_index = start_n * BLOCK_N
+        block_mask = tl.where(block_start_index < cur_batch_seq_len, 1, 0)
+
+        for start_mark in range(0, block_mask, 1):
+            q = tl.load(Q + off_q + start_mark)
+            offs_n_new = cur_batch_start_index + offs_n
+            k_loc = tl.load(B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new,
+                            mask=offs_n_new < cur_batch_end_index,
+                            other=0)
+            off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
+            k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
+            att_value = tl.sum(q[None, :] * k, 1)
+            att_value *= sm_scale
+            off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
+            tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
+        return
+
+    @staticmethod
+    @torch.no_grad()
+    def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):
+        BLOCK = 32
+        # shape constraints
+        Lq, Lk = q.shape[-1], k.shape[-1]
+        assert Lq == Lk
+        assert Lk in {16, 32, 64, 128}
+        sm_scale = 1.0 / (Lk**0.5)
+
+        batch, head_num = B_Loc.shape[0], q.shape[1]
+
+        grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))
+        kv_group_num = q.shape[1] // k.shape[1]
+
+        # num_warps = 4 if Lk <= 64 else 8
+        num_warps = 2
+
+        Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid](
+            q,
+            k,
+            sm_scale,
+            B_Loc,
+            B_Start_Loc,
+            B_Seqlen,
+            max_input_len,
+            att_out,
+            B_Loc.stride(0),
+            B_Loc.stride(1),
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            att_out.stride(0),
+            att_out.stride(1),
+            kv_group_num=kv_group_num,
+            BLOCK_DMODEL=Lk,
+            BLOCK_N=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+        return
+
+    @staticmethod
+    @triton.jit
+    def _fwd_kernel_token_att2(
+            Prob,
+            V,
+            Out,
+            B_Loc,
+            B_Start_Loc,
+            B_Seqlen,
+            max_input_len,    # B_Start_Loc stores the cumulative input offsets assuming contiguous storage
+            stride_b_loc_b,
+            stride_b_loc_s,
+            stride_ph,
+            stride_pbs,
+            stride_vbs,
+            stride_vh,
+            stride_vd,
+            stride_obs,
+            stride_oh,
+            stride_od,
+            kv_group_num,
+            BLOCK_DMODEL: tl.constexpr,
+            BLOCK_N: tl.constexpr):
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+
+        cur_kv_head = cur_head // kv_group_num
+
+        offs_n = tl.arange(0, BLOCK_N)
+        offs_d = tl.arange(0, BLOCK_DMODEL)
+        
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_index = max_input_len - cur_batch_seq_len + cur_batch_end_index = cur_batch_seq_len + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s + p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load(Prob + p_offs + start_n * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0.0) + v_loc = tl.load(B_Loc + v_loc_off + start_n * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0.0) + v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = acc.to(tl.float16) + off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + return + + @staticmethod + @torch.no_grad() + def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = B_Loc.shape[0], prob.shape[0] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + kv_group_num = prob.shape[0] // v.shape[1] + + Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid]( + prob, + v, + out, + B_Loc, + B_Start_Loc, + B_Seqlen, + max_input_len, + B_Loc.stride(0), + B_Loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + # this is the interface of llama2 attn forward + @staticmethod + @torch.no_grad() + def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, + other_kv_index): + total_token_num = k.shape[0] + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") + + Llama2TokenAttentionForwards.token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + if triton.__version__ == "2.0.0": + prob = torch.empty_like(att_m_tensor) + Llama2TokenAttentionForwards.token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, + max_len_in_batch) + att_m_tensor = None + + Llama2TokenAttentionForwards.token_att_fwd2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, + kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch) + prob = None + + return + + elif triton.__version__ >= "2.1.0": + + Llama2TokenAttentionForwards.token_softmax_reducev_fwd(att_m_tensor, v, attn_out.view(calcu_shape1), + kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, + max_len_in_batch, other_kv_index) + else: + raise Exception("not support triton version") diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 49613ffb37e0..50230047faec 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py 
@@ -142,6 +142,9 @@ class PolicyLocation:
         PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
     "transformers.models.bloom.modeling_bloom.BloomForCausalLM":
         PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
+    # ChatGLM2
+    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
+        PolicyLocation(file_name="chatglm2", class_name="ChatGLM2InferPolicy"),
 }
 
 
@@ -186,7 +189,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
 
     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} is not implemented.\nSupported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
         policy = import_policy(policy_location, inference_only)
diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py
new file mode 100644
index 000000000000..d6293771d1ec
--- /dev/null
+++ b/tests/test_infer/test_chatglm2_infer.py
@@ -0,0 +1,75 @@
+import os
+
+import numpy as np
+import pytest
+import torch
+import torch.distributed as dist
+from packaging import version
+from transformers import AutoModel, AutoTokenizer
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+TPSIZE = 1
+BATCH_SIZE = 8
+MAX_INPUT_LEN = 12
+MAX_OUTPUT_LEN = 100
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+
+
+def run_chatglm2_test():
+
+    chatglm2_model_path = "/home/lccjh/data2/lccjh/chatglm2-6b"
+    assert os.path.isdir(chatglm2_model_path)
+
+    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+    # pad_token_id = 0
+
+    model = ChatGLMModel.from_pretrained(chatglm2_model_path, pad_token_id=tokenizer.eos_token_id)
+    #init_to_get_rotary(model.model, base=10000)
+    model = model.half()
+
+    text = ["how is weather today?", "i am "]
+    input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
+
+    #print("input ids ", input_ids)
+    infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+    shardformer = ShardFormer(shard_config=shard_config)
+
+    infer_engine.prepare_with_shard_config(shard_config)
+    infer_engine.shard_model_by(shardformer)
+
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
+    outputs = infer_engine.generate(input_ids, generate_kwargs)
+    print("outputs.shape: ", outputs[0].shape)
+
+    print("outputs: ", outputs[0])
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        for o in outputs:
+            output_text = tokenizer.decode(o)
+            print(output_text)
+
+
+def check_chatglm2(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_chatglm2_test()
+
+
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be 
higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_chatglm2(): + spawn(check_chatglm2, TPSIZE) + + +if __name__ == "__main__": + test_chatglm2() From 2bca31f44b46b78d25e1b9e3d9d8fdf8cd9e24a4 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 7 Sep 2023 19:03:58 +0800 Subject: [PATCH 04/22] fix some bugs --- .../tensor_parallel/modeling/_utils.py | 2 +- .../tensor_parallel/modeling/chatglm2.py | 48 +++++++++++++------ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py index 282871cabac0..cee418707617 100644 --- a/colossalai/inference/tensor_parallel/modeling/_utils.py +++ b/colossalai/inference/tensor_parallel/modeling/_utils.py @@ -4,7 +4,7 @@ from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest -def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) return diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 7870fbedcdb2..e6e10e526c4e 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -1,7 +1,6 @@ import os from typing import List, Optional, Tuple -import _utils import numpy as np import torch from torch.nn import CrossEntropyLoss @@ -18,9 +17,11 @@ GLMBlock, GLMTransformer, SelfAttention, - apply_rotary_pos_emb, + split_tensor_along_last_dim, ) +from ._utils import _copy_kv_to_mem_cache + try: from vllm import layernorm_ops, pos_encoding_ops rms_norm = layernorm_ops.rms_norm @@ -381,7 +382,6 @@ def chatglm_flash_attn_kvcache_forward( # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) - if self.multi_query_attention: (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ @@ -391,6 +391,9 @@ def chatglm_flash_attn_kvcache_forward( ], dim=-1, ) + + print(key_layer.shape) + print(value_layer.shape) query_layer = query_layer.view(query_layer.size()[:-1] + ( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, @@ -403,6 +406,7 @@ def chatglm_flash_attn_kvcache_forward( self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head, )) + else: new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, @@ -419,8 +423,18 @@ def chatglm_flash_attn_kvcache_forward( # query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) # key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - rotary_embedding_fwd(query_layer.view(-1, self.num_heads, self.head_dim), cos, sin) - rotary_embedding_fwd(key_layer.view(-1, self.num_heads, self.head_dim), cos, sin) + rotary_embedding_fwd( + query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin) + if self.multi_query_attention: + rotary_embedding_fwd( + key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), cos, + sin) + else: + rotary_embedding_fwd( + key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, 
+ sin) + + #The shape of key value pair will return to [sq, b , num_heads, num_hidden_size] after rotary embedding, the logic is kept same as original if self.multi_query_attention: key_layer = key_layer.unsqueeze(-2) @@ -447,21 +461,25 @@ def chatglm_flash_attn_kvcache_forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, )) + print('key,value shape', key_layer.shape, value_layer.shape) - query_layer = query_layer.reshape(-1, self.num_heads, self.head_dim) - key_layer = key_layer.reshape(-1, self.num_heads, self.head_dim) - value_layer = value_layer.reshape(-1, self.num_heads, self.head_dim) + query_layer = query_layer.reshape(-1, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + key_layer = key_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + value_layer = value_layer.reshape(-1, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) if infer_state.is_context_stage: # first token generation: # copy key and value calculated in current step to memory manager - _utils._copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, - infer_state.context_mem_index, infer_state.cache_manager) + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, + infer_state.cache_manager) attn_output = torch.empty_like(query_layer) llama2_context_attn_fwd(query_layer, key_layer, value_layer, attn_output, infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length) + print('context stage', attn_output.shape) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -474,8 +492,8 @@ def chatglm_flash_attn_kvcache_forward( else: # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - _utils._copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, - infer_state.decode_mem_index, infer_state.cache_manager) + _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.decode_mem_index, + infer_state.cache_manager) # second token and follows # kv = torch.stack((key_states, value_states), dim=2) @@ -491,13 +509,13 @@ def chatglm_flash_attn_kvcache_forward( # ================================== # core attention computation # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + print('attn_out', attn_output.shape) + #context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # ================= # Output. 
[sq, b, h] # ================= - output = self.dense(context_layer) + output = self.dense(attn_output) return output, kv_cache From b33da6ab30a9d9ecd2006cc9b2001de00515be32 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 7 Sep 2023 22:39:02 +0800 Subject: [PATCH 05/22] finish context forward --- .../tensor_parallel/modeling/chatglm2.py | 21 ++++++++++--------- .../modeling/chatglm2_6b/modeling_chatglm.py | 2 -- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index e6e10e526c4e..94f48adbd762 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -20,7 +20,7 @@ split_tensor_along_last_dim, ) -from ._utils import _copy_kv_to_mem_cache +from ._utils import copy_kv_to_mem_cache try: from vllm import layernorm_ops, pos_encoding_ops @@ -372,7 +372,7 @@ def chatglm_flash_attn_kvcache_forward( ): assert use_cache is True, "use_cache should be set to True using this chatglm attention" # hidden_states: [sq, b, h] - + batch_size = hidden_states.shape[1] # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= @@ -468,12 +468,12 @@ def chatglm_flash_attn_kvcache_forward( key_layer = key_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) value_layer = value_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - + print('reshaped', query_layer.shape, key_layer.shape, value_layer.shape) if infer_state.is_context_stage: # first token generation: # copy key and value calculated in current step to memory manager - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, - infer_state.cache_manager) + copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, + infer_state.cache_manager) attn_output = torch.empty_like(query_layer) @@ -481,6 +481,7 @@ def chatglm_flash_attn_kvcache_forward( infer_state.seq_len, infer_state.cache_manager.past_key_values_length) print('context stage', attn_output.shape) else: + print('token attention') if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ @@ -492,8 +493,8 @@ def chatglm_flash_attn_kvcache_forward( else: # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.decode_mem_index, - infer_state.cache_manager) + copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.decode_mem_index, + infer_state.cache_manager) # second token and follows # kv = torch.stack((key_states, value_states), dim=2) @@ -509,13 +510,13 @@ def chatglm_flash_attn_kvcache_forward( # ================================== # core attention computation # ================================== - print('attn_out', attn_output.shape) + #context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # ================= - # Output. [sq, b, h] + # Output. 
[sq, b, h] 7,2,4096 for test # ================= - output = self.dense(attn_output) + output = self.dense(attn_output.view(-1, batch_size, self.projection_size)) return output, kv_cache diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index a21ee0231422..8267afc91622 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -520,7 +520,6 @@ def forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, )) - # ================================== # core attention computation # ================================== @@ -530,7 +529,6 @@ def forward( # ================= # Output. [sq, b, h] # ================= - output = self.dense(context_layer) return output, kv_cache From 6e71d697e52d13d29dcff6e4f4a1f81831da7f9d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 8 Sep 2023 16:18:14 +0800 Subject: [PATCH 06/22] finish context stage --- .../tensor_parallel/modeling/chatglm2.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 94f48adbd762..3b3c676d1c81 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -475,10 +475,14 @@ def chatglm_flash_attn_kvcache_forward( copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, infer_state.cache_manager) - attn_output = torch.empty_like(query_layer) + print(query_layer) - llama2_context_attn_fwd(query_layer, key_layer, value_layer, attn_output, infer_state.start_loc, - infer_state.seq_len, infer_state.cache_manager.past_key_values_length) + attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + + llama2_context_attn_fwd( + query_layer, key_layer, value_layer, + attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length) print('context stage', attn_output.shape) else: print('token attention') @@ -499,8 +503,9 @@ def chatglm_flash_attn_kvcache_forward( # second token and follows # kv = torch.stack((key_states, value_states), dim=2) # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_layer) + attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + print('in token attn kernel') Llama2TokenAttentionForwards.token_attn( query_layer, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, infer_state.block_loc, @@ -517,6 +522,6 @@ def chatglm_flash_attn_kvcache_forward( # Output. 
[sq, b, h] 7,2,4096 for test # ================= - output = self.dense(attn_output.view(-1, batch_size, self.projection_size)) + output = self.dense(attn_output).reshape(7, 2, 4096) return output, kv_cache From 95bd76ca4399153cbf81161fea90b2c21748d8b4 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sat, 9 Sep 2023 00:09:36 +0800 Subject: [PATCH 07/22] fix --- .../tensor_parallel/modeling/chatglm2.py | 34 ++++++++++++++----- .../kernel/triton/token_attention_kernel.py | 8 +++++ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 3b3c676d1c81..b8e608f00b09 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -159,9 +159,14 @@ def chatglm_model_forward( batch_size, seq_length = input_ids.shape infer_state = self.infer_state - seq_length_with_past = seq_length past_key_values_length = 0 + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length + past_key_values_length + # prefill stage at first if use_cache and seq_length != 1: infer_state.is_context_stage = True @@ -169,6 +174,7 @@ def chatglm_model_forward( infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index) else: + print('total token num', infer_state.total_token_num) infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) if alloc_mem is not None: @@ -200,7 +206,8 @@ def chatglm_model_forward( seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() + print('max_len_in_batch', infer_state.max_len_in_batch) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch].item() if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) @@ -247,6 +254,7 @@ def chatglm_model_forward( # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.cache_manager.past_key_values_length += seq_length if not return_dict: return tuple(v for v in [ @@ -472,6 +480,7 @@ def chatglm_flash_attn_kvcache_forward( if infer_state.is_context_stage: # first token generation: # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, infer_state.cache_manager) @@ -506,11 +515,20 @@ def chatglm_flash_attn_kvcache_forward( attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) print('in token attn kernel') - Llama2TokenAttentionForwards.token_attn( - query_layer, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, infer_state.block_loc, - infer_state.start_loc, infer_state.seq_len, 
infer_state.cache_manager.past_key_values_length, - infer_state.other_kv_index) + print('other kv index', infer_state.other_kv_index) + print('query_layer', query_layer.shape) + print('kv', key_layer.shape, value_layer.shape) + print('attn_output', attn_output.shape) + + cache_k = infer_state.cache_manager.key_buffer[ + infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] + cache_v = infer_state.cache_manager.value_buffer[ + infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] + + Llama2TokenAttentionForwards.token_attn(query_layer, cache_k, cache_v, attn_output, infer_state.block_loc, + infer_state.start_loc, infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + infer_state.other_kv_index) # ================================== # core attention computation @@ -522,6 +540,6 @@ def chatglm_flash_attn_kvcache_forward( # Output. [sq, b, h] 7,2,4096 for test # ================= - output = self.dense(attn_output).reshape(7, 2, 4096) + output = self.dense(attn_output).reshape(-1, batch_size, self.projection_size) return output, kv_cache diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 81528cebd70d..3c2cf2f3616f 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -680,9 +680,13 @@ def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): @torch.no_grad() def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index): + print('key buffer', k.shape) total_token_num = k.shape[0] + print('total_token_num', total_token_num) head_num = k.shape[1] + print('head num', head_num) batch_size = kv_cache_seq_len.shape[0] + print('batch size', batch_size) calcu_shape1 = (batch_size, head_num, k.shape[2]) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") @@ -696,6 +700,9 @@ def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq kv_cache_seq_len, max_len_in_batch, ) + print('in kernel') + print('query ', q.shape) + print('att_m_tensor', att_m_tensor.shape) if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) @@ -705,6 +712,7 @@ def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq Llama2TokenAttentionForwards.token_att_fwd2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch) + print(attn_out.shape) prob = None return From ac6cf9fcd2dd5f523ccf7229d622dc9f9a66958a Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sat, 9 Sep 2023 22:06:13 +0800 Subject: [PATCH 08/22] add --- .../tensor_parallel/modeling/chatglm2.py | 34 +++---------------- .../kernel/triton/token_attention_kernel.py | 8 ----- 2 files changed, 4 insertions(+), 38 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index b8e608f00b09..8b99bd5eb2d0 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -174,7 +174,6 @@ def chatglm_model_forward( infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index) else: - print('total token num', infer_state.total_token_num) infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) if alloc_mem is not None: @@ -206,7 +205,6 @@ 
def chatglm_model_forward( seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - print('max_len_in_batch', infer_state.max_len_in_batch) infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch].item() if inputs_embeds is None: @@ -297,7 +295,6 @@ def chatglm_encoder_forward( layer_ret = layer( hidden_states, attention_mask, - rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache, infer_state=infer_state, @@ -323,7 +320,6 @@ def chatglm_glmblock_forward( self: GLMBlock, hidden_states, attention_mask, - rotary_pos_emb, kv_cache=None, use_cache=True, infer_state: Optional[BatchInferState] = None, @@ -336,7 +332,6 @@ def chatglm_glmblock_forward( attention_output, kv_cache = self.self_attention( layernorm_output, attention_mask, - rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache, infer_state=infer_state, @@ -373,7 +368,6 @@ def chatglm_flash_attn_kvcache_forward( self: SelfAttention, hidden_states, attention_mask, - rotary_pos_emb, kv_cache=None, use_cache=True, infer_state: Optional[BatchInferState] = None, @@ -399,9 +393,6 @@ def chatglm_flash_attn_kvcache_forward( ], dim=-1, ) - - print(key_layer.shape) - print(value_layer.shape) query_layer = query_layer.view(query_layer.size()[:-1] + ( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, @@ -426,11 +417,6 @@ def chatglm_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - # apply relative positional encoding (rotary embedding) - # if rotary_pos_emb is not None: - # query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - # key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - rotary_embedding_fwd( query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin) if self.multi_query_attention: @@ -469,32 +455,26 @@ def chatglm_flash_attn_kvcache_forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, )) - print('key,value shape', key_layer.shape, value_layer.shape) - + # reshape q k v to [bsz*sql, num_heads, head_dim] query_layer = query_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) key_layer = key_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) value_layer = value_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - print('reshaped', query_layer.shape, key_layer.shape, value_layer.shape) if infer_state.is_context_stage: # first token generation: # copy key and value calculated in current step to memory manager copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, infer_state.cache_manager) - print(query_layer) - attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) llama2_context_attn_fwd( query_layer, key_layer, value_layer, attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length) - print('context stage', attn_output.shape) else: - print('token attention') if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly cache_k = 
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ @@ -514,12 +494,6 @@ def chatglm_flash_attn_kvcache_forward( # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) - print('in token attn kernel') - print('other kv index', infer_state.other_kv_index) - print('query_layer', query_layer.shape) - print('kv', key_layer.shape, value_layer.shape) - print('attn_output', attn_output.shape) - cache_k = infer_state.cache_manager.key_buffer[ infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] cache_v = infer_state.cache_manager.value_buffer[ @@ -531,13 +505,13 @@ def chatglm_flash_attn_kvcache_forward( infer_state.other_kv_index) # ================================== - # core attention computation + # core attention computation is replaced by triton kernel # ================================== - #context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + # context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # ================= - # Output. [sq, b, h] 7,2,4096 for test + # Output. [sq, b, h] # ================= output = self.dense(attn_output).reshape(-1, batch_size, self.projection_size) diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 3c2cf2f3616f..81528cebd70d 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -680,13 +680,9 @@ def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): @torch.no_grad() def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index): - print('key buffer', k.shape) total_token_num = k.shape[0] - print('total_token_num', total_token_num) head_num = k.shape[1] - print('head num', head_num) batch_size = kv_cache_seq_len.shape[0] - print('batch size', batch_size) calcu_shape1 = (batch_size, head_num, k.shape[2]) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") @@ -700,9 +696,6 @@ def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq kv_cache_seq_len, max_len_in_batch, ) - print('in kernel') - print('query ', q.shape) - print('att_m_tensor', att_m_tensor.shape) if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) @@ -712,7 +705,6 @@ def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq Llama2TokenAttentionForwards.token_att_fwd2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch) - print(attn_out.shape) prob = None return From dae60a8ca3a6af8213b09490c0184aed57e9b50e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 11 Sep 2023 16:25:50 +0800 Subject: [PATCH 09/22] pause --- .../tensor_parallel/modeling/__init__.py | 2 ++ .../tensor_parallel/modeling/chatglm2.py | 24 +++++++++++++------ colossalai/kernel/triton/context_attention.py | 5 +++- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py index 85b2bdf8c09c..166aab659526 100644 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -1,3 +1,5 @@ +import _utils + from .bloom import BloomInferenceForwards from .chatglm2 import ChatGLM2InferenceForwards from .llama import 
LlamaInferenceForwards diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 8b99bd5eb2d0..a353543ff6a6 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -303,6 +303,7 @@ def chatglm_encoder_forward( infer_state.decode_layer_id += 1 hidden_states, kv_cache = layer_ret + print(hidden_states[0][0]) if use_cache: presents = presents + (kv_cache,) @@ -327,7 +328,9 @@ def chatglm_glmblock_forward( # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. + print('glm block', hidden_states[0][0]) layernorm_output = self.input_layernorm(hidden_states) + print('after layernorm', layernorm_output[0][0]) # Self attention. attention_output, kv_cache = self.self_attention( layernorm_output, @@ -336,19 +339,20 @@ def chatglm_glmblock_forward( use_cache=use_cache, infer_state=infer_state, ) - + print(attention_output.shape) + print('attn output', attention_output[0][0]) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states - + print('attn output', attention_output[0][1]) layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) layernorm_input = residual + layernorm_input - + print('layernorm input', layernorm_input[0][0]) # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) - + print('layernorm output', layernorm_output[0][0]) # MLP. mlp_output = self.mlp(layernorm_output) @@ -467,13 +471,18 @@ def chatglm_flash_attn_kvcache_forward( copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, infer_state.cache_manager) - print(query_layer) + if infer_state.decode_layer_id <= 2: + print(torch.isnan(query_layer).any()) + print(torch.isnan(key_layer).any()) + print(torch.isnan(value_layer).any()) attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) llama2_context_attn_fwd( query_layer, key_layer, value_layer, attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length) + if infer_state.decode_layer_id <= 2: + print('attn output', attn_output) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -498,7 +507,8 @@ def chatglm_flash_attn_kvcache_forward( infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] cache_v = infer_state.cache_manager.value_buffer[ infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] - + #print(infer_state.decode_layer_id) + #print('cache k,v',cache_k[0],cache_v[0]) Llama2TokenAttentionForwards.token_attn(query_layer, cache_k, cache_v, attn_output, infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length, @@ -515,5 +525,5 @@ def chatglm_flash_attn_kvcache_forward( # ================= output = self.dense(attn_output).reshape(-1, batch_size, self.projection_size) - + print('after dense', output[0][0]) return output, kv_cache diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 9c244bfc2b8b..897071c1f86c 100644 --- a/colossalai/kernel/triton/context_attention.py +++ 
b/colossalai/kernel/triton/context_attention.py @@ -488,7 +488,11 @@ def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_stages=1, ) return + elif triton.__version__ == "2.0.0": + #print('query',q[0]) + # print('key',k[0]) + # print('value',v[0]) BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -500,7 +504,6 @@ def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): kv_group_num = q.shape[1] // k.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 From 0c45002ac95f8de304a5f7c13181deafcc1e18a9 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 11 Sep 2023 16:50:03 +0800 Subject: [PATCH 10/22] add --- colossalai/inference/tensor_parallel/modeling/chatglm2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index a353543ff6a6..ec68ac9159a9 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -483,6 +483,7 @@ def chatglm_flash_attn_kvcache_forward( infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length) if infer_state.decode_layer_id <= 2: print('attn output', attn_output) + print(torch.isnan(attn_output).any()) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly From 20b2df0f5294f98aba87ddb0507ba3747c40e367 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 12 Sep 2023 18:52:00 +0800 Subject: [PATCH 11/22] fix bugs --- colossalai/inference/tensor_parallel/engine.py | 9 ++++----- .../tensor_parallel/policies/chatglm2.py | 2 +- tests/test_infer/test_chatglm2_infer.py | 17 ++++++++--------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 223320f1cd97..22ea341ff06d 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -59,11 +59,10 @@ def __init__(self, self.dtype = dtype - self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads - self.head_num = self.model.config.num_attention_heads - - num_hidden_layers = self.model.config.num_hidden_layers if hasattr(self.model.config, "num_hidden_layers") \ - else self.model.config.num_layers + self.head_dim = model.config.hidden_size // model.config.num_attention_heads + self.head_num = model.config.num_attention_heads + num_hidden_layers = model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") \ + else model.config.num_layers self.layer_num = num_hidden_layers self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py index 0e4ec73d49b9..cb3c464ef9e3 100644 --- a/colossalai/inference/tensor_parallel/policies/chatglm2.py +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -9,7 +9,7 @@ SelfAttention, ) # import colossalai -from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy +from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy from ..modeling.chatglm2 import 
ChatGLM2InferenceForwards, _init_to_get_rotary diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index d6293771d1ec..74269f8a58b5 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -23,7 +23,10 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') -def run_chatglm2_test(): +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_chatglm2_test(test_config): chatglm2_model_path = "/home/lccjh/data2/lccjh/chatglm2-6b" assert os.path.isdir(chatglm2_model_path) is True @@ -37,17 +40,13 @@ def run_chatglm2_test(): text = ["how is weather today?", "i am "] input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True) - + shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, + inference_only=True) #print("input ids ", input_ids) - infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) - shardformer = ShardFormer(shard_config=shard_config) - - infer_engine.prepare_with_shard_config(shard_config) - infer_engine.shard_model_by(shardformer) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) + outputs = infer_engine.generate(input_ids, **generate_kwargs) print("outputs.shape: ", outputs[0].shape) print("outputs: ", outputs[0]) From 58e0a6b7436e4288bbe969aa89c79f8723103a3a Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Thu, 14 Sep 2023 16:46:16 +0800 Subject: [PATCH 12/22] finish chatglm --- .../inference/tensor_parallel/engine.py | 9 +- .../tensor_parallel/modeling/chatglm2.py | 201 +++++++++--------- .../tensor_parallel/policies/chatglm2.py | 19 ++ .../shardformer/policies/auto_policy.py | 2 + tests/test_infer/test_chatglm2_infer.py | 4 +- 5 files changed, 137 insertions(+), 98 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 22ea341ff06d..8c7712376e0e 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -15,7 +15,9 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM', 'ChatGLMModel'] +_supported_models = [ + 'LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration' +] class TPInferEngine: @@ -252,6 +254,11 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch model = self.model.transformer setattr(model, 'infer_state', batch_infer_state) + # outputs = self.model.forward(input_tokens['input_ids'], attention_mask=input_tokens['attention_mask']) + # outputs = self.model.forward(input_tokens['input_ids'][:, 0].unsqueeze(1), attention_mask=input_tokens['attention_mask']) + # outputs = self.model.forward(input_tokens['input_ids'][:, 1].unsqueeze(1), attention_mask=input_tokens['attention_mask']) + + # FOR test chatglm2 outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) # NOTE In future development, we're going to let the scheduler to handle the cache, diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index ec68ac9159a9..0c7cda681dbb 100644 --- 
a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -72,6 +72,28 @@ def _init_to_get_rotary(self, base=10000): return +def get_masks(self, input_ids, past_length, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + if past_length: + full_attention_mask = torch.cat( + ( + torch.ones(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + class ChatGLM2InferenceForwards: """ This class holds forwards for Chatglm2 inference. @@ -95,17 +117,74 @@ def chatglm_for_conditional_generation_forward( ): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + infer_state = self.infer_state - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + seq_length_with_past = seq_length + past_key_values_length + infer_state.seq_length_with_past = seq_length_with_past + + # prefill stage at first + if use_cache and seq_length != 1: + infer_state.is_context_stage = True + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, + infer_state.context_mem_index) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), 
dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + #related to rotary embedding + if infer_state.is_context_stage: + + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch].item() + + transformer_outputs = self.transformer(input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state) hidden_states = transformer_outputs[0] if return_last_logit: @@ -151,62 +230,13 @@ def chatglm_model_forward( use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, ): output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) batch_size, seq_length = input_ids.shape - - infer_state = self.infer_state - past_key_values_length = 0 - - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length + past_key_values_length - - # prefill stage at first - if use_cache and seq_length != 1: - infer_state.is_context_stage = True - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.context_mem_index) - else: - infer_state.is_context_stage = False - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - - #related to rotary embedding - if infer_state.is_context_stage: - - infer_state.position_cos = 
torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) - else: - seq_len = infer_state.seq_len - infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch].item() - if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) @@ -225,10 +255,12 @@ def chatglm_model_forward( ], dim=-1, ) - if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + full_attention_mask = get_masks(self, + input_ids, + infer_state.cache_manager.past_key_values_length, + padding_mask=attention_mask) # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) @@ -303,7 +335,6 @@ def chatglm_encoder_forward( infer_state.decode_layer_id += 1 hidden_states, kv_cache = layer_ret - print(hidden_states[0][0]) if use_cache: presents = presents + (kv_cache,) @@ -328,9 +359,7 @@ def chatglm_glmblock_forward( # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. - print('glm block', hidden_states[0][0]) layernorm_output = self.input_layernorm(hidden_states) - print('after layernorm', layernorm_output[0][0]) # Self attention. attention_output, kv_cache = self.self_attention( layernorm_output, @@ -339,20 +368,15 @@ def chatglm_glmblock_forward( use_cache=use_cache, infer_state=infer_state, ) - print(attention_output.shape) - print('attn output', attention_output[0][0]) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states - print('attn output', attention_output[0][1]) layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) layernorm_input = residual + layernorm_input - print('layernorm input', layernorm_input[0][0]) # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) - print('layernorm output', layernorm_output[0][0]) # MLP. mlp_output = self.mlp(layernorm_output) @@ -382,9 +406,6 @@ def chatglm_flash_attn_kvcache_forward( # ================================================= # Pre-allocate memory for key-values for inference. 
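        # (Putting the pieces above together: in the prefill pass, cos/sin rows are
        # gathered per prompt token with index_select on position_ids and KV slots
        # for the whole prompt are reserved in a single alloc; in decode, each
        # sequence reads the one row at seq_len - 1 and appends a single KV slot,
        # contiguously when alloc_contiguous succeeds. The _cos_cached/_sin_cached
        # tables here are the ones precomputed by _init_to_get_rotary.)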
# ================================================= - # ===================== - # Query, Key, and Value - # ===================== # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) @@ -459,6 +480,7 @@ def chatglm_flash_attn_kvcache_forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, )) + # reshape q k v to [bsz*sql, num_heads, head_dim] query_layer = query_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) @@ -471,19 +493,15 @@ def chatglm_flash_attn_kvcache_forward( copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, infer_state.cache_manager) - if infer_state.decode_layer_id <= 2: - print(torch.isnan(query_layer).any()) - print(torch.isnan(key_layer).any()) - print(torch.isnan(value_layer).any()) + attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + # NOTE: no bug in context attn fwd (del it ) llama2_context_attn_fwd( query_layer, key_layer, value_layer, attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), - infer_state.start_loc, infer_state.seq_len, infer_state.cache_manager.past_key_values_length) - if infer_state.decode_layer_id <= 2: - print('attn output', attn_output) - print(torch.isnan(attn_output).any()) + infer_state.start_loc, infer_state.seq_len, infer_state.seq_length_with_past) + else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -500,31 +518,24 @@ def chatglm_flash_attn_kvcache_forward( infer_state.cache_manager) # second token and follows - # kv = torch.stack((key_states, value_states), dim=2) - # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) cache_k = infer_state.cache_manager.key_buffer[ infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] cache_v = infer_state.cache_manager.value_buffer[ infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] - #print(infer_state.decode_layer_id) - #print('cache k,v',cache_k[0],cache_v[0]) + + # ================================== + # core attention computation is replaced by triton kernel + # ================================== Llama2TokenAttentionForwards.token_attn(query_layer, cache_k, cache_v, attn_output, infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - infer_state.other_kv_index) - - # ================================== - # core attention computation is replaced by triton kernel - # ================================== - - # context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + infer_state.seq_length_with_past, infer_state.other_kv_index) + #print('after attention',torch.isnan(attn_output).any()) # ================= - # Output. [sq, b, h] + # Output:[sq, b, h] ,it is kept same as original. 
# ================= output = self.dense(attn_output).reshape(-1, batch_size, self.projection_size) - print('after dense', output[0][0]) return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py index cb3c464ef9e3..f0bd50f5b54b 100644 --- a/colossalai/inference/tensor_parallel/policies/chatglm2.py +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -3,6 +3,7 @@ import torch from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( + ChatGLMForConditionalGeneration, ChatGLMModel, GLMBlock, GLMTransformer, @@ -57,3 +58,21 @@ def module_policy(self): def postprocess(self): _init_to_get_rotary(self.model) return self.model + + +class ChatGLM2ForConditionalGenerationInferPolicy(ChatGLM2InferPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + model_infer_forward = ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward + method_replacement = {'forward': partial(model_infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=ChatGLMForConditionalGeneration) + return policy + + def postprocess(self): + return super().postprocess() diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 50230047faec..e0a1278b4d30 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -145,6 +145,8 @@ class PolicyLocation: # ChatGLM2 "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(file_name="chatglm2", class_name="ChatGLM2InferPolicy"), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": + PolicyLocation(file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"), } diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 74269f8a58b5..8f15f19929a4 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -12,7 +12,7 @@ from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -34,7 +34,7 @@ def run_chatglm2_test(test_config): tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) # pad_token_id = 0 - model = ChatGLMModel.from_pretrained(chatglm2_model_path, pad_token_id=tokenizer.eos_token_id) + model = ChatGLMForConditionalGeneration.from_pretrained(chatglm2_model_path, pad_token_id=tokenizer.eos_token_id) #init_to_get_rotary(model.model, base=10000) model = model.half() From 185ff154bf5b0059a0e8f9674a9341136f474e6e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 18 Sep 2023 16:50:28 +0800 Subject: [PATCH 13/22] fix bug --- .../tensor_parallel/modeling/chatglm2.py | 25 ++--- .../kernel/triton/rotary_embedding_kernel.py | 97 +++++++++++++++++++ .../kernel/triton/token_attention_kernel.py | 10 +- 
tests/test_infer/test_chatglm2_infer.py | 8 +- .../triton/test_llama2_token_attn.py | 69 +++++++++++++ 5 files changed, 186 insertions(+), 23 deletions(-) create mode 100644 tests/test_infer_ops/triton/test_llama2_token_attn.py diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 0c7cda681dbb..fb8e24fd7e01 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -9,7 +9,7 @@ from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest -from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd +from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -61,7 +61,6 @@ def _init_to_get_rotary(self, base=10000): base = base * (ntk_alpha**(self.head_dim_ / (self.head_dim_ - 2))) #Base change formula except: pass - n_elem = self.config.head_dim_ // 2 inv_freq = 1.0 / (base**(torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor @@ -174,7 +173,7 @@ def chatglm_for_conditional_generation_forward( seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch].item() + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() transformer_outputs = self.transformer(input_ids=input_ids, position_ids=position_ids, @@ -237,6 +236,7 @@ def chatglm_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) batch_size, seq_length = input_ids.shape + if inputs_embeds is None: inputs_embeds = self.embedding(input_ids) @@ -284,6 +284,7 @@ def chatglm_model_forward( # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 infer_state.cache_manager.past_key_values_length += seq_length if not return_dict: @@ -320,7 +321,6 @@ def chatglm_encoder_forward( all_hidden_states = () if output_hidden_states else None infer_state.decode_layer_id = 0 - for index in range(self.num_layers): layer = self.layers[index] @@ -388,7 +388,6 @@ def chatglm_glmblock_forward( output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) output = residual + output - return output, kv_cache @staticmethod @@ -408,7 +407,9 @@ def chatglm_flash_attn_kvcache_forward( # ================================================= # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + if 
self.multi_query_attention: (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ @@ -442,14 +443,14 @@ def chatglm_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - rotary_embedding_fwd( + Llama2Forwards.rotary_emb_fwd( query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin) if self.multi_query_attention: - rotary_embedding_fwd( + Llama2Forwards.rotary_emb_fwd( key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), cos, sin) else: - rotary_embedding_fwd( + Llama2Forwards.rotary_emb_fwd( key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin) @@ -481,7 +482,7 @@ def chatglm_flash_attn_kvcache_forward( self.hidden_size_per_attention_head, )) - # reshape q k v to [bsz*sql, num_heads, head_dim] + # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32 ,128 query_layer = query_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) key_layer = key_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) @@ -509,6 +510,7 @@ def chatglm_flash_attn_kvcache_forward( infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + # 2, 32, 128 cache_k.copy_(key_layer) cache_v.copy_(value_layer) else: @@ -519,18 +521,17 @@ def chatglm_flash_attn_kvcache_forward( # second token and follows attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) - cache_k = infer_state.cache_manager.key_buffer[ infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] cache_v = infer_state.cache_manager.value_buffer[ infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] - # ================================== # core attention computation is replaced by triton kernel # ================================== Llama2TokenAttentionForwards.token_attn(query_layer, cache_k, cache_v, attn_output, infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.seq_length_with_past, infer_state.other_kv_index) + infer_state.max_len_in_batch, infer_state.other_kv_index) + #print('after attention',torch.isnan(attn_output).any()) # ================= diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py index d9d1b2bcf026..a63e2ea95f63 100644 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -91,3 +91,100 @@ def rotary_embedding_fwd(q, cos, sin): num_stages=1, ) return + + +class Llama2Forwards: + + @staticmethod + @triton.jit + def _rotary_kernel( + Q, + Cos, + Sin, + stride_qbs, + stride_qh, + stride_qd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, + max_total_len, + H, # N_CTX 代表要计算的上下文长度 + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ): + + cur_head_index = tl.program_id(0) + cur_seq_index = tl.program_id(1) + + cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + dim_range1 = dim_range0 + 1 + off_q0 = cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range0[ + None, None, :] * stride_qd + off_q1 = 
cur_seq_range[:, None, None] * stride_qbs + cur_head_range[None, :, None] * stride_qh + dim_range1[ + None, None, :] * stride_qd + + cos_range = tl.arange(0, BLOCK_DMODEL // 2) + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd + + q0 = tl.load(Q + off_q0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), + other=0.0) + q1 = tl.load(Q + off_q1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), + other=0.0) + + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out0 = q0 * cos - q1 * sin + out1 = q0 * sin + q1 * cos + + tl.store(Q + off_q0, + out0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)) + tl.store(Q + off_q1, + out1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)) + + return + + @staticmethod + @torch.no_grad() + def rotary_emb_fwd(q, cos, sin): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] // 2 + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + Llama2Forwards._rotary_kernel[grid]( + q, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 81528cebd70d..61c1b5be8487 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -681,14 +681,12 @@ def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, other_kv_index): total_token_num = k.shape[0] - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - + batch_size, head_num, head_dim = q.shape + calcu_shape1 = (batch_size, head_num, head_dim) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") Llama2TokenAttentionForwards.token_att_fwd( - q.view(calcu_shape1), + q, k, att_m_tensor, kv_cache_loc, @@ -705,8 +703,8 @@ def token_attn(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq Llama2TokenAttentionForwards.token_att_fwd2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch) - prob = None + prob = None return elif triton.__version__ >= "2.1.0": diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 8f15f19929a4..e49ebbc67e16 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -38,18 +38,16 @@ def run_chatglm2_test(test_config): #init_to_get_rotary(model.model, base=10000) model = model.half() - text = ["how is weather today?", "i am "] + text = ["how is the weather today?"] input_ids = 
tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
     shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
                                inference_only=True)
-    #print("input ids ", input_ids)
     infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
 
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
     outputs = infer_engine.generate(input_ids, **generate_kwargs)
-    print("outputs.shape: ", outputs[0].shape)
-
-    print("outputs: ", outputs[0])
+    # print("outputs.shape: ", outputs[0].shape)
+    # print("outputs: ", outputs[0])
     if not dist.is_initialized() or dist.get_rank() == 0:
         for o in outputs:
             output_text = tokenizer.decode(o)
diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py
new file mode 100644
index 000000000000..5087b304dca1
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_llama2_token_attn.py
@@ -0,0 +1,69 @@
+import math
+
+import pytest
+import torch
+from packaging import version
+
+try:
+    import triton
+    import triton.language as tl
+
+    from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
+    xq = xq.view(bs, 1, num_head, head_dim)
+    xk = xk.view(bs, seqlen, num_head, head_dim)
+    xv = xv.view(bs, seqlen, num_head, head_dim)
+
+    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
+    prob = torch.softmax(logics, dim=1)
+    prob = prob.view(bs, seqlen, num_head, 1)
+
+    return torch.sum(prob * xv, dim=1, keepdim=False)
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+                    reason="triton requires cuda version to be higher than 11.4")
+def test():
+
+    Z, head_num, seq_len, head_dim = 2, 32, 2048, 128
+    dtype = torch.float16
+
+    # attn out: 2,4096
+    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
+    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
+    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
+    o = torch.empty_like(q)
+    #o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
+
+    max_kv_cache_len = seq_len
+    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
+    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
+    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
+    other_kv_index = 2048
+
+    kv_cache_seq_len[:] = seq_len
+    kv_cache_start_loc[0] = 0
+    kv_cache_start_loc[1] = seq_len
+
+    for i in range(Z):
+        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
+
+    Llama2TokenAttentionForwards.token_attn(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
+                                            max_kv_cache_len, other_kv_index)
+    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
+
+    print("max ", torch.max(torch.abs(torch_out - o)))
+    print("mean ", torch.mean(torch.abs(torch_out - o)))
+    assert torch.allclose(torch_out, o, atol=1e-3, rtol=0)
+
+
+if __name__ == "__main__":
+    test()
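For readers following the new Triton kernel in this patch: `Llama2Forwards._rotary_kernel` rotates only the first half of each head dimension, treating it as interleaved (even, odd) pairs, which appears to match ChatGLM2's rotary layout. The sketch below is an editorial PyTorch reference, not code from the tree; it assumes `q` has shape `[total_tokens, num_heads, head_dim]` and that `cos`/`sin` carry one value per rotated pair, as the buffers built by `_init_to_get_rotary` do.

    import torch

    def rotary_ref(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # q: [total_tokens, num_heads, head_dim]
        # cos, sin: [total_tokens, head_dim // 4] (one entry per rotated pair)
        rot = q.shape[-1] // 2              # only the first half of head_dim is rotated
        q0 = q[:, :, 0:rot:2]               # even slots inside the rotated half
        q1 = q[:, :, 1:rot:2]               # odd slots inside the rotated half
        c = cos.unsqueeze(1)                # broadcast over the head axis
        s = sin.unsqueeze(1)
        out = q.clone()
        out[:, :, 0:rot:2] = q0 * c - q1 * s
        out[:, :, 1:rot:2] = q0 * s + q1 * c
        return out

Comparing this reference against `Llama2Forwards.rotary_emb_fwd` (which updates `q` in place) with `torch.allclose(..., atol=1e-3)` on float16 inputs gives the rotary kernel the same kind of sanity check that `test_llama2_token_attn.py` gives the attention kernel.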
From 321127b5f1e927d25ee519a02069a72ffb625275 Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Mon, 18 Sep 2023 18:52:04 +0800
Subject: [PATCH 14/22] change some logic

---
 .../inference/tensor_parallel/engine.py       | 10 ++-
 .../tensor_parallel/modeling/chatglm2.py      | 72 ++++++++++---------
 .../tensor_parallel/policies/chatglm2.py      |  1 -
 colossalai/kernel/triton/context_attention.py |  4 --
 tests/test_infer/test_chatglm2_infer.py       |  3 +-
 5 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index 8c7712376e0e..52cb9880adf4 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -66,6 +66,8 @@ def __init__(self,
         num_hidden_layers = model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") \
             else model.config.num_layers
         self.layer_num = num_hidden_layers
+        self.multi_query_group_num = model.config.multi_query_group_num if hasattr(model.config,
+                                                                                   "multi_query_group_num") else 0
         self.tp_size = -1    # to be set with given shard config in self.prepare_shard_config
         self.cache_manager = None
@@ -79,8 +81,12 @@ def _init_manager(self) -> None:
         assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size    # update sharded number of heads
-        self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
-                                           self.layer_num)
+        if self.multi_query_group_num:
+            self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.multi_query_group_num,
+                                               self.head_dim, self.layer_num)
+        else:
+            self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
+                                               self.layer_num)
 
     def _optimize_model(self, model: nn.Module) -> None:
         """
diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
index fb8e24fd7e01..b20a32d5e647 100644
--- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py
+++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -313,7 +313,7 @@ def chatglm_encoder_forward(
         output_hidden_states: Optional[bool] = False,
         infer_state: Optional[BatchInferState] = None,
     ):
-
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
@@ -342,6 +342,8 @@
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         # Final layer norm.
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         if self.post_layer_norm:
             hidden_states = self.final_layernorm(hidden_states)
 
@@ -401,13 +403,12 @@ def chatglm_flash_attn_kvcache_forward(
     ):
         assert use_cache is True, "use_cache should be set to True using this chatglm attention"
         # hidden_states: [sq, b, h]
-        batch_size = hidden_states.shape[1]
+        batch_size = hidden_states.shape[0]
         # =================================================
         # Pre-allocate memory for key-values for inference.
# ================================================= # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) if self.multi_query_attention: @@ -456,37 +457,38 @@ def chatglm_flash_attn_kvcache_forward( #The shape of key value pair will return to [sq, b , num_heads, num_hidden_size] after rotary embedding, the logic is kept same as original - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, - -1, - -1, - self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, - -1, - ) - key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, - -1, - -1, - self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, - -1, - ) - value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - - # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32 ,128 + # if self.multi_query_attention: + # key_layer = key_layer.unsqueeze(-2) + # key_layer = key_layer.expand( + # -1, + # -1, + # -1, + # self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + # -1, + # ) + # key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( + # self.num_attention_heads_per_partition, + # self.hidden_size_per_attention_head, + # )) + # value_layer = value_layer.unsqueeze(-2) + # value_layer = value_layer.expand( + # -1, + # -1, + # -1, + # self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, + # -1, + # ) + # value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( + # self.num_attention_heads_per_partition, + # self.hidden_size_per_attention_head, + # )) + + # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128 query_layer = query_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - key_layer = key_layer.reshape(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - value_layer = value_layer.reshape(-1, self.num_attention_heads_per_partition, + key_layer = key_layer.reshape(-1, self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head) + value_layer = value_layer.reshape(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) if infer_state.is_context_stage: # first token generation: @@ -510,7 +512,6 @@ def chatglm_flash_attn_kvcache_forward( infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] - # 2, 32, 128 cache_k.copy_(key_layer) cache_v.copy_(value_layer) else: @@ -525,6 +526,7 @@ def chatglm_flash_attn_kvcache_forward( infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] cache_v = infer_state.cache_manager.value_buffer[ infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] + # ================================== # core attention computation is replaced by triton kernel # ================================== @@ -535,8 +537,8 @@ def chatglm_flash_attn_kvcache_forward( #print('after attention',torch.isnan(attn_output).any()) # ================= - # Output:[sq, b, h] ,it is kept same as original. 
+ # Output:[b,sq, h] # ================= - output = self.dense(attn_output).reshape(-1, batch_size, self.projection_size) + output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size) return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py index f0bd50f5b54b..472bb1363e38 100644 --- a/colossalai/inference/tensor_parallel/policies/chatglm2.py +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -52,7 +52,6 @@ def module_policy(self): target_key=SelfAttention) # for rmsnorm and others, we need to check the shape - return policy def postprocess(self): diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 897071c1f86c..583eed8c38b3 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -452,7 +452,6 @@ def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] @@ -490,9 +489,6 @@ def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): return elif triton.__version__ == "2.0.0": - #print('query',q[0]) - # print('key',k[0]) - # print('value',v[0]) BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index e49ebbc67e16..ccf7e6d845c5 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -38,7 +38,7 @@ def run_chatglm2_test(test_config): #init_to_get_rotary(model.model, base=10000) model = model.half() - text = ["how is the weather today?"] + text = ["how is the weather today?", "i am ", "你好?"] input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True) shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, inference_only=True) @@ -46,6 +46,7 @@ def run_chatglm2_test(test_config): generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) outputs = infer_engine.generate(input_ids, **generate_kwargs) + # print("outputs.shape: ", outputs[0].shape) # print("outputs: ", outputs[0]) if not dist.is_initialized() or dist.get_rank() == 0: From 065c834a62adb6b9681bdf8750786a8aba3810c0 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 19 Sep 2023 10:56:27 +0800 Subject: [PATCH 15/22] fix bugs --- .../inference/tensor_parallel/engine.py | 5 -- .../tensor_parallel/modeling/chatglm2.py | 56 +------------------ .../tensor_parallel/policies/chatglm2.py | 8 +-- colossalai/kernel/triton/context_attention.py | 2 +- 4 files changed, 6 insertions(+), 65 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 52cb9880adf4..8cd42b4ddf3e 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -260,11 +260,6 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch model = self.model.transformer setattr(model, 'infer_state', batch_infer_state) - # outputs = self.model.forward(input_tokens['input_ids'], attention_mask=input_tokens['attention_mask']) - # outputs = 
self.model.forward(input_tokens['input_ids'][:, 0].unsqueeze(1), attention_mask=input_tokens['attention_mask']) - # outputs = self.model.forward(input_tokens['input_ids'][:, 1].unsqueeze(1), attention_mask=input_tokens['attention_mask']) - - # FOR test chatglm2 outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) # NOTE In future development, we're going to let the scheduler to handle the cache, diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index b20a32d5e647..178cd7974822 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -22,19 +22,6 @@ from ._utils import copy_kv_to_mem_cache -try: - from vllm import layernorm_ops, pos_encoding_ops - rms_norm = layernorm_ops.rms_norm - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print( - "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" - ) - HAS_VLLM_KERNERL = False - # This func is same as Llama model init_to_get_rotary, we should move them into _utils.py def _init_to_get_rotary(self, base=10000): @@ -262,19 +249,10 @@ def chatglm_model_forward( infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask) - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - # Run encoder. hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, - rotary_pos_emb=rotary_pos_emb, kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, @@ -307,7 +285,6 @@ def chatglm_encoder_forward( self: GLMTransformer, hidden_states, attention_mask, - rotary_pos_emb, kv_caches=None, use_cache: Optional[bool] = True, output_hidden_states: Optional[bool] = False, @@ -402,11 +379,8 @@ def chatglm_flash_attn_kvcache_forward( infer_state: Optional[BatchInferState] = None, ): assert use_cache is True, "use_cache should be set to True using this chatglm attention" - # hidden_states: [sq, b, h] + # hidden_states: original :[sq, b, h] --> this [b, sq, h] batch_size = hidden_states.shape[0] - # ================================================= - # Pre-allocate memory for key-values for inference. 
-        # =================================================
 
         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer = self.query_key_value(hidden_states)
 
@@ -455,34 +429,6 @@ def chatglm_flash_attn_kvcache_forward(
                 key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos,
                 sin)
 
-        #The shape of key value pair will return to [sq, b , num_heads, num_hidden_size] after rotary embedding, the logic is kept same as original
-
-        # if self.multi_query_attention:
-        #     key_layer = key_layer.unsqueeze(-2)
-        #     key_layer = key_layer.expand(
-        #         -1,
-        #         -1,
-        #         -1,
-        #         self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
-        #         -1,
-        #     )
-        #     key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (
-        #         self.num_attention_heads_per_partition,
-        #         self.hidden_size_per_attention_head,
-        #     ))
-        #     value_layer = value_layer.unsqueeze(-2)
-        #     value_layer = value_layer.expand(
-        #         -1,
-        #         -1,
-        #         -1,
-        #         self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
-        #         -1,
-        #     )
-        #     value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (
-        #         self.num_attention_heads_per_partition,
-        #         self.hidden_size_per_attention_head,
-        #     ))
-
         # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128
         query_layer = query_layer.reshape(-1, self.num_attention_heads_per_partition,
                                           self.hidden_size_per_attention_head)
diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py
index 472bb1363e38..cb223370a65d 100644
--- a/colossalai/inference/tensor_parallel/policies/chatglm2.py
+++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py
@@ -32,21 +32,21 @@ def module_policy(self):
         self.shard_config._infer()
 
         model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
-        method_replacement = {'forward': partial(model_infer_forward)}
+        method_replacement = {'forward': model_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement,
                                                  policy=policy,
                                                  target_key=ChatGLMModel)
 
         encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
-        method_replacement = {'forward': partial(encoder_infer_forward)}
+        method_replacement = {'forward': encoder_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement,
                                                  policy=policy,
                                                  target_key=GLMTransformer)
 
         encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
-        method_replacement = {'forward': partial(encoder_layer_infer_forward)}
+        method_replacement = {'forward': encoder_layer_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement,
                                                  policy=policy,
                                                  target_key=GLMBlock)
 
         attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
-        method_replacement = {'forward': partial(attn_infer_forward)}
+        method_replacement = {'forward': attn_infer_forward}
         self.append_or_create_method_replacement(description=method_replacement,
                                                  policy=policy,
                                                  target_key=SelfAttention)
diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py
index 583eed8c38b3..afe6df7141f7 100644
--- a/colossalai/kernel/triton/context_attention.py
+++ b/colossalai/kernel/triton/context_attention.py
@@ -238,7 +238,7 @@ def _fwd_kernel_latest(
     V,
     sm_scale,
     B_Start_Loc,
-    B_Seqlen,    # B_LOC records each batch input's real start position; B_SEQ_len records the real length of the current input
+    B_Seqlen,
     Out,
     stride_qbs,
     stride_qh,
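With the multi-query changes above, key/value tensors stay at `num_multi_query_groups_per_partition` heads instead of being expanded to the full query-head count; the Triton kernels resolve the sharing through `kv_group_num = q.shape[1] // k.shape[1]`. A minimal sketch of that bookkeeping (an editorial reference with hypothetical names, ignoring batching and masking) could look like this:

    import torch

    def grouped_kv_attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q: [num_tokens, num_q_heads, head_dim]; k, v: [num_tokens, num_kv_heads, head_dim]
        kv_group_num = q.shape[1] // k.shape[1]    # query heads that share one KV head
        # The kernels index the small K/V buffers directly via kv_group_num; a naive
        # reference can instead materialize the shared heads.
        k = k.repeat_interleave(kv_group_num, dim=1)
        v = v.repeat_interleave(kv_group_num, dim=1)
        scores = torch.einsum("qhd,khd->hqk", q, k) / (q.shape[-1] ** 0.5)
        return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)

This sharing is also why `_init_manager` can size the `MemoryManager` by `multi_query_group_num` rather than `head_num`: ChatGLM2-6B uses 32 query heads but only 2 KV groups, so the cache stores 2 KV heads per token instead of 32.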
From 212e01bc90d6c80c0992211b4c0948cd3fbbfaec Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Tue, 19 Sep 2023 13:17:52 +0800
Subject: [PATCH 16/22] change some logic

---
 .../inference/tensor_parallel/engine.py       |  1 +
 .../modeling/chatglm2_6b/modeling_chatglm.py  |  2 -
 tests/kit/model_zoo/transformers/chatglm2.py  |  4 +-
 tests/test_infer/test_chatglm2_infer.py       | 44 +++++++++----------
 4 files changed, 25 insertions(+), 26 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index 8cd42b4ddf3e..9fb7d70df92a 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -82,6 +82,7 @@ def _init_manager(self) -> None:
         assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
         self.head_num //= self.tp_size    # update sharded number of heads
         if self.multi_query_group_num:
+            assert self.multi_query_group_num % self.tp_size == 0, f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}"
             self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.multi_query_group_num,
                                                self.head_dim, self.layer_num)
         else:
diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
index 8267afc91622..9358df05fa09 100644
--- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
@@ -385,12 +385,10 @@ class SelfAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number, device=None):
         super(SelfAttention, self).__init__()
         self.layer_number = max(1, layer_number)
-
         self.projection_size = config.kv_channels * config.num_attention_heads
 
         # Per attention head and per partition values.
self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) self.num_attention_heads_per_partition = config.num_attention_heads - self.multi_query_attention = config.multi_query_attention self.qkv_hidden_size = 3 * self.projection_size if self.multi_query_attention: diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index d543df00bdfa..d74a1b29b788 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -36,8 +36,10 @@ def data_gen_for_conditional_generation(): config = ChatGLMConfig(num_layers=2, padded_vocab_size=65024, - hidden_size=64, + hidden_size=1024, num_attention_heads=8, + multi_query_attention=True, + multi_query_group_num=2, rmsnorm=True, original_rope=True, use_cache=True, diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index ccf7e6d845c5..52bb87f3d9dd 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -8,12 +8,12 @@ from transformers import AutoModel, AutoTokenizer import colossalai -from colossalai.cluster import ProcessGroupMesh from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 1 @@ -28,31 +28,29 @@ }]) def run_chatglm2_test(test_config): - chatglm2_model_path = "/home/lccjh/data2/lccjh/chatglm2-6b" - assert os.path.isdir(chatglm2_model_path) is True + sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm_for_conditional_generation') tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) # pad_token_id = 0 - - model = ChatGLMForConditionalGeneration.from_pretrained(chatglm2_model_path, pad_token_id=tokenizer.eos_token_id) - #init_to_get_rotary(model.model, base=10000) - model = model.half() - - text = ["how is the weather today?", "i am ", "你好?"] - input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True) - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) - infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - outputs = infer_engine.generate(input_ids, **generate_kwargs) - - # print("outputs.shape: ", outputs[0].shape) - # print("outputs: ", outputs[0]) - if not dist.is_initialized() or dist.get_rank() == 0: - for o in outputs: - output_text = tokenizer.decode(o) - print(output_text) + for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): + orig_model = model_fn() + orig_model = orig_model.half() + text = ["how is the weather today?"] + input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True) + shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, + inference_only=True) + infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + 
outputs = infer_engine.generate(input_ids, **generate_kwargs) + assert outputs is not None + + # print("outputs.shape: ", outputs[0].shape) + # print("outputs: ", outputs[0]) + if not dist.is_initialized() or dist.get_rank() == 0: + for o in outputs: + output_text = tokenizer.decode(o) + print(output_text) def check_chatglm2(rank, world_size, port): From 0ddcfb1455ea462fbf5a77053c4dad3b8b3cdabd Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 19 Sep 2023 13:42:29 +0800 Subject: [PATCH 17/22] add --- colossalai/inference/tensor_parallel/engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 9fb7d70df92a..c18e6fd38187 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -82,9 +82,11 @@ def _init_manager(self) -> None: assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" self.head_num //= self.tp_size # update sharded number of heads if self.multi_query_group_num: + # NOTE the logic of MQA tensor parallelism should be specified. assert self.multi_query_group_num % self.tp_size == 0, f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}" - self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.multi_query_group_num, - self.head_dim, self.layer_num) + self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, + self.multi_query_group_num // self.tp_size, self.head_dim, + self.layer_num) else: self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num) From 1c36ad96e938d91730565b80d9b5f616416c7c3d Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 19 Sep 2023 13:54:54 +0800 Subject: [PATCH 18/22] add --- tests/kit/model_zoo/torchrec/__init__.py | 2 +- tests/kit/model_zoo/transformers/chatglm2.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index d74a1b29b788..c62d121db223 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -36,10 +36,11 @@ def data_gen_for_conditional_generation(): config = ChatGLMConfig(num_layers=2, padded_vocab_size=65024, - hidden_size=1024, - num_attention_heads=8, + hidden_size=64, + num_attention_heads=4, multi_query_attention=True, multi_query_group_num=2, + kv_channels=16, rmsnorm=True, original_rope=True, use_cache=True, From 5da13fa0173e90a1771662423782525e315f47a4 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Tue, 19 Sep 2023 13:54:58 +0800 Subject: [PATCH 19/22] add --- tests/kit/model_zoo/transformers/chatglm2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index c62d121db223..2d8826ea05f6 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -36,8 +36,8 @@ def data_gen_for_conditional_generation(): config = ChatGLMConfig(num_layers=2, padded_vocab_size=65024, - 
hidden_size=64, - num_attention_heads=4, + hidden_size=128, + num_attention_heads=8, multi_query_attention=True, multi_query_group_num=2, kv_channels=16, From 9cf099cd5dda412e8bf58ef334d4a6bbff351b92 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Wed, 20 Sep 2023 18:48:26 +0800 Subject: [PATCH 20/22] fix --- .../kernel/triton/rotary_embedding_kernel.py | 2 +- .../kernel/triton/token_attention_kernel.py | 4 ++-- tests/kit/model_zoo/torchrec/__init__.py | 2 +- .../triton/test_llama2_token_attn.py | 24 ++++++++----------- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py index ea849ea53119..fd74ba817551 100644 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -122,7 +122,7 @@ def _rotary_kernel( stride_sinbs, stride_sind, max_total_len, - H, # N_CTX 代表要计算的上下文长度 + H, # N_CTX BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index 7a0ffdf279e9..c27394f0f9cf 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -425,7 +425,7 @@ def _fwd_kernel( stride_od, stride_b_loc_b, stride_b_loc_s, - other_kv_index, # 避免读取到nan的数据 + other_kv_index, # avoid nan information kv_group_num, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -693,7 +693,7 @@ def _fwd_kernel_token_att2( B_Loc, B_Start_Loc, B_Seqlen, - max_input_len, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 + max_input_len, # B_Start_Loc cumsum of input lens if continuous stride_b_loc_b, stride_b_loc_s, stride_ph, diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py index 5087b304dca1..c22f70211d4f 100644 --- a/tests/test_infer_ops/triton/test_llama2_token_attn.py +++ b/tests/test_infer_ops/triton/test_llama2_token_attn.py @@ -1,20 +1,18 @@ -import math - import pytest import torch from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): @@ -29,10 +27,10 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): return torch.sum(prob * xv, dim=1, keepdim=False) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test(): - Z, head_num, seq_len, head_dim = 2, 32, 2048, 128 dtype = torch.float16 @@ -41,7 +39,7 @@ def test(): k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) v = 
torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
     o = torch.empty_like(q)
-    #o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
+    # o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
 
     max_kv_cache_len = seq_len
     kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
     kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
     kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
     other_kv_index = 2048
 
     kv_cache_seq_len[:] = seq_len
     kv_cache_start_loc[0] = 0
     kv_cache_start_loc[1] = seq_len
 
     for i in range(Z):
         kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
 
-    Llama2TokenAttentionForwards.token_attn(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len,
-                                            max_kv_cache_len, other_kv_index)
+    Llama2TokenAttentionForwards.token_attn(
+        q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index
+    )
     torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
-
-    print("max ", torch.max(torch.abs(torch_out - o)))
-    print("mean ", torch.mean(torch.abs(torch_out - o)))
     assert torch.allclose(torch_out, o, atol=1e-3, rtol=0)

From 0dc780dfbcfe79737975e4d55da461a71100733c Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Thu, 21 Sep 2023 16:47:34 +0800
Subject: [PATCH 21/22] fix tests

---
 .../tensor_parallel/modeling/chatglm2.py      | 200 +++++++++++-------
 .../modeling/chatglm2_6b/modeling_chatglm.py  |   1 -
 tests/kit/model_zoo/torchrec/__init__.py      |   2 +-
 tests/kit/model_zoo/transformers/chatglm2.py  |  12 ++
 tests/test_infer/test_chatglm2_infer.py       |  64 +++---
 5 files changed, 171 insertions(+), 108 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
index 178cd7974822..4b1bc601f436 100644
--- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py
+++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -1,14 +1,12 @@
 import os
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
-import numpy as np
 import torch
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
-from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
 from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
 from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMForConditionalGeneration,
@@ -45,11 +43,11 @@ def _init_to_get_rotary(self, base=10000):
         if ntk_alpha > 1:
             print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
         max_seq_len *= ntk_alpha
-        base = base * (ntk_alpha**(self.head_dim_ / (self.head_dim_ - 2)))    #Base change formula
+        base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2)))    # Base change formula
     except:
         pass
     n_elem = self.config.head_dim_ // 2
-    inv_freq = 1.0 / (base**(torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
     t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
     freqs = torch.outer(t, inv_freq)
@@ -102,7 +100,7 @@ def chatglm_for_conditional_generation_forward(
         return_last_logit: Optional[bool] = False,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = 
(return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict infer_state = self.infer_state if input_ids is not None and inputs_embeds is not None: @@ -126,8 +124,9 @@ def chatglm_for_conditional_generation_forward( if use_cache and seq_length != 1: infer_state.is_context_stage = True infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.context_mem_index) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) else: infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) @@ -149,28 +148,31 @@ def chatglm_for_conditional_generation_forward( # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - #related to rotary embedding + # related to rotary embedding if infer_state.is_context_stage: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) + position_ids.view(-1).shape[0], -1 + ) infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) + position_ids.view(-1).shape[0], -1 + ) else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() - transformer_outputs = self.transformer(input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state) + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) hidden_states = transformer_outputs[0] if return_last_logit: @@ -218,10 +220,11 @@ def chatglm_model_forward( return_dict: Optional[bool] = None, infer_state: BatchInferState = None, ): - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape if inputs_embeds is None: @@ -244,10 +247,9 @@ def chatglm_model_forward( ) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = get_masks(self, - input_ids, - 
infer_state.cache_manager.past_key_values_length, - padding_mask=attention_mask) + full_attention_mask = get_masks( + self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask + ) # Run encoder. hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( @@ -256,7 +258,8 @@ def chatglm_model_forward( kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, - infer_state=infer_state) + infer_state=infer_state, + ) # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") @@ -266,12 +269,16 @@ def chatglm_model_forward( infer_state.cache_manager.past_key_values_length += seq_length if not return_dict: - return tuple(v for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -381,7 +388,6 @@ def chatglm_flash_attn_kvcache_forward( assert use_cache is True, "use_cache should be set to True using this chatglm attention" # hidden_states: original :[sq, b, h] --> this [b, sq, h] batch_size = hidden_states.shape[0] - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) @@ -394,18 +400,27 @@ def chatglm_flash_attn_kvcache_forward( ], dim=-1, ) - query_layer = query_layer.view(query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - key_layer = key_layer.view(key_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) - value_layer = value_layer.view(value_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) else: new_tensor_shape = mixed_x_layer.size()[:-1] + ( @@ -419,68 +434,103 @@ def chatglm_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin Llama2Forwards.rotary_emb_fwd( - query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin) + query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin + ) if self.multi_query_attention: Llama2Forwards.rotary_emb_fwd( - key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), cos, - sin) + key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) else: Llama2Forwards.rotary_emb_fwd( - key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, - sin) + key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), + cos, + sin, + ) # reshape q k v to [bsz*sql, num_heads, head_dim] 2*1 ,32/2 ,128 - query_layer = 
query_layer.reshape(-1, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - key_layer = key_layer.reshape(-1, self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head) - value_layer = value_layer.reshape(-1, self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head) + query_layer = query_layer.reshape( + -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ) + key_layer = key_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) + value_layer = value_layer.reshape( + -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head + ) if infer_state.is_context_stage: # first token generation: # copy key and value calculated in current step to memory manager - copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.context_mem_index, - infer_state.cache_manager) + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.context_mem_index, + infer_state.cache_manager, + ) attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) # NOTE: no bug in context attn fwd (del it ) llama2_context_attn_fwd( - query_layer, key_layer, value_layer, + query_layer, + key_layer, + value_layer, attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), - infer_state.start_loc, infer_state.seq_len, infer_state.seq_length_with_past) + infer_state.start_loc, + infer_state.seq_len, + infer_state.seq_length_with_past, + ) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_k.copy_(key_layer) cache_v.copy_(value_layer) else: # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - copy_kv_to_mem_cache(infer_state.decode_layer_id, key_layer, value_layer, infer_state.decode_mem_index, - infer_state.cache_manager) + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_layer, + value_layer, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) # second token and follows attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) - cache_k = infer_state.cache_manager.key_buffer[ - infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] - cache_v = infer_state.cache_manager.value_buffer[ - infer_state.decode_layer_id][:infer_state.decode_mem_end, :, :] + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + : infer_state.decode_mem_end, :, : + ] # ================================== # core attention computation is replaced by triton kernel # ================================== - Llama2TokenAttentionForwards.token_attn(query_layer, cache_k, cache_v, attn_output, infer_state.block_loc, - infer_state.start_loc, infer_state.seq_len, - 
diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
index 2168b370ea86..cbb25b5b1f4c 100644
--- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
@@ -443,7 +443,6 @@ def forward(
 
         # Attention heads  [sq, b, h]  --> [sq, b, (np * 3 * hn)]
         mixed_x_layer = self.query_key_value(hidden_states)
-
         if self.multi_query_attention:
             (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                 [
diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py
index 43952e6998cf..458d7bc81ce8 100644
--- a/tests/kit/model_zoo/torchrec/__init__.py
+++ b/tests/kit/model_zoo/torchrec/__init__.py
@@ -1 +1 @@
-from .torchrec import *
+# from .torchrec import *
diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py
index dbe0f1e14e51..f4369cb7d171 100644
--- a/tests/kit/model_zoo/transformers/chatglm2.py
+++ b/tests/kit/model_zoo/transformers/chatglm2.py
@@ -35,6 +35,18 @@ def data_gen_for_conditional_generation():
 loss_fn = lambda x: x.loss
 
 config = ChatGLMConfig(
+    num_layers=2,
+    padded_vocab_size=65024,
+    hidden_size=64,
+    num_attention_heads=8,
+    kv_channels=16,
+    rmsnorm=True,
+    original_rope=True,
+    use_cache=True,
+    torch_dtype=torch.float32,
+)
+
+infer_config = ChatGLMConfig(
     num_layers=2,
     padded_vocab_size=65024,
     hidden_size=128,
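The model-zoo hunk above ends up with two ChatGLMConfig instances: a tiny `config` for shardformer tests and a separate `infer_config` consumed by the inference test. Two layers and a small hidden size keep instantiation cheap on CI, while the real padded_vocab_size keeps tokenizer output valid. A sketch of how the test below consumes it; this mirrors the diff rather than documenting a guaranteed API, and the import paths are assumed from the hunks:

    from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
    from tests.kit.model_zoo.transformers.chatglm2 import infer_config

    # empty_init=False forces real weight allocation so .half() and generation work
    model = ChatGLMForConditionalGeneration(infer_config, empty_init=False).half()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"tiny ChatGLM2 test model: {n_params / 1e6:.1f}M params")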
diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py
index 52bb87f3d9dd..699ba7b52fe0 100644
--- a/tests/test_infer/test_chatglm2_infer.py
+++ b/tests/test_infer/test_chatglm2_infer.py
@@ -1,61 +1,63 @@
 import os
 
-import numpy as np
 import pytest
 import torch
 import torch.distributed as dist
 from packaging import version
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoTokenizer
 
 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
+from tests.kit.model_zoo.transformers.chatglm2 import infer_config
 
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 1
 BATCH_SIZE = 8
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
 
 
-@parameterize('test_config', [{
-    'tp_size': TPSIZE,
-}])
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": TPSIZE,
+        }
+    ],
+)
 def run_chatglm2_test(test_config):
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm_for_conditional_generation')
-
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)    # pad_token_id = 0
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        orig_model = model_fn()
-        orig_model = orig_model.half()
-        text = ["how is the weather today?"]
-        input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True)
-        shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
-                                   inference_only=True)
-        infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False)
+    orig_model = model_fn()
+    orig_model = orig_model.half()
+    text = ["how is the weather today?"]
+    input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
 
-        generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-        outputs = infer_engine.generate(input_ids, **generate_kwargs)
-        assert outputs is not None
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
+    outputs = infer_engine.generate(input_ids, **generate_kwargs)
+    assert outputs is not None
 
-        # print("outputs.shape: ", outputs[0].shape)
-        # print("outputs: ", outputs[0])
-        if not dist.is_initialized() or dist.get_rank() == 0:
-            for o in outputs:
-                output_text = tokenizer.decode(o)
-                print(output_text)
+    # print("outputs.shape: ", outputs[0].shape)
+    # print("outputs: ", outputs[0])
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        for o in outputs:
+            output_text = tokenizer.decode(o)
+            print(output_text)
 
 
 def check_chatglm2(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_chatglm2_test()

From 6c477a4e0e92eb5f3b024d54f70c8085c9477b53 Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Thu, 21 Sep 2023 16:48:43 +0800
Subject: [PATCH 22/22] fix

---
 tests/kit/model_zoo/torchrec/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py
index 458d7bc81ce8..43952e6998cf 100644
--- a/tests/kit/model_zoo/torchrec/__init__.py
+++ b/tests/kit/model_zoo/torchrec/__init__.py
@@ -1 +1 @@
-# from .torchrec import *
+from .torchrec import *
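For orientation, these tests follow the repo's spawn-per-rank pattern: a check_* function launches colossalai on each rank, and a pytest entry point spawns TPSIZE processes. The patch does not show that entry point; the sketch below is an assumed shape modeled on the imports the test file declares (clear_cache_before_run, rerun_if_address_is_in_use, spawn), with the exact decorator stack being a guess rather than quoted code:

    import pytest

    # clear_cache_before_run, rerun_if_address_is_in_use, spawn are already
    # imported at the top of the test file above; CUDA_SUPPORT, TPSIZE, and
    # check_chatglm2 are defined there as well.

    @pytest.mark.skipif(not CUDA_SUPPORT, reason="inference kernels require CUDA > 11.5")
    @clear_cache_before_run()
    @rerun_if_address_is_in_use()
    def test_chatglm2():
        # one process per tensor-parallel rank; each runs check_chatglm2
        spawn(check_chatglm2, TPSIZE)

    if __name__ == "__main__":
        test_chatglm2()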