diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
index a6ee58f1e00d..9768fc425628 100644
--- a/colossalai/inference/tensor_parallel/modeling/bloom.py
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -140,7 +140,7 @@ def bloom_model_forward(
         # if self.cache_manager.past_key_values_length > 0:
         if infer_state.cache_manager.past_key_values_length > 0:
             # update the past key values length in cache manager,
-            # TODO use BatchInferState.past_key_values_length instead the one in cache manager
+            # NOTE use BatchInferState.past_key_values_length instead the one in cache manager
             past_key_values_length = infer_state.cache_manager.past_key_values_length
             seq_length_with_past = seq_length_with_past + past_key_values_length

@@ -178,7 +178,7 @@ def bloom_model_forward(
         else:
             attention_mask = attention_mask.to(hidden_states.device)

-        # TODO revise: we might want to store a single 1D alibi(length is #heads) in model,
+        # NOTE revise: we might want to store a single 1D alibi(length is #heads) in model,
         # or store to BatchInferState to prevent re-calculating
         # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here
         # alibi = generate_alibi(self.num_heads).contiguous().cuda()
@@ -445,6 +445,9 @@ def bloom_attention_forward(
         mem_manager = infer_state.cache_manager
         layer_id = infer_state.decode_layer_id

+        if layer_id == 0:    # once per model.forward
+            infer_state.cache_manager.past_key_values_length += q_length    # += 1
+
         if infer_state.is_context_stage:
             # context process
             max_input_len = q_length
@@ -461,10 +464,6 @@ def bloom_attention_forward(
             bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
             context_layer = output.view(batch_size, q_length, H * D_HEAD)

-            # record the length of past key values cache when entering the first attention layer in bloom block,
-            # since we won't return past_key_value_cache right now
-            if layer_id == 0:    # once per model.forward
-                infer_state.cache_manager.past_key_values_length = q_length    # seq_len
         else:
             # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
             # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD)
@@ -485,20 +484,15 @@ def bloom_attention_forward(
             copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
             copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])

-            b_start_loc = infer_state.start_loc[:batch_size]
-            b_loc = infer_state.block_loc[:batch_size, :]
-            b_seq_len = infer_state.seq_len[:batch_size]
-            max_len_in_batch = mem_manager.past_key_values_length + q_length
+            b_start_loc = infer_state.start_loc
+            b_loc = infer_state.block_loc
+            b_seq_len = infer_state.seq_len

             output = torch.empty_like(q)
             token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc,
-                                b_start_loc, b_seq_len, max_len_in_batch, alibi)
+                                b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi)
             context_layer = output.view(batch_size, q_length, H * D_HEAD)

-            if layer_id == 0:    # once per model.forward
-                assert infer_state.cache_manager.past_key_values_length != 0
-                infer_state.cache_manager.past_key_values_length += q_length    # += 1
-
         # update layer id
         infer_state.decode_layer_id += 1
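For reference, a minimal sketch of the bookkeeping pattern the bloom.py hunks converge on (hypothetical `_CacheManager` / `_InferState` stand-ins, not the ColossalAI classes): the shared past-KV length is advanced exactly once per forward pass, at layer 0, before any attention kernel reads it, so both the context stage and the decode stage can pass `infer_state.cache_manager.past_key_values_length` straight to the kernels.

```python
# Sketch only: simplified stand-ins for the cache manager and batch infer state.
from dataclasses import dataclass


@dataclass
class _CacheManager:
    past_key_values_length: int = 0


@dataclass
class _InferState:
    cache_manager: _CacheManager
    decode_layer_id: int = 0


def _attention_bookkeeping(infer_state: _InferState, q_length: int) -> int:
    # Advance the shared length exactly once per model.forward, in the first layer,
    # so every layer (context or decode stage) reads the same "max length in batch".
    if infer_state.decode_layer_id == 0:
        infer_state.cache_manager.past_key_values_length += q_length
    return infer_state.cache_manager.past_key_values_length


state = _InferState(cache_manager=_CacheManager())
for layer_id in range(3):    # prefill with a 5-token prompt
    state.decode_layer_id = layer_id
    assert _attention_bookkeeping(state, q_length=5) == 5
state.decode_layer_id = 0    # next forward: one decode token
assert _attention_bookkeeping(state, q_length=1) == 6
```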
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 0d8ed5dc442f..82f294163fd7 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -54,12 +54,16 @@ def llama_model_forward(

         infer_state = self.infer_state

         b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
-        position_ids = torch.from_numpy(
-            np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
-        # this equals
-        infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
-        infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
+        if HAS_VLLM_KERNERL:
+            position_ids = torch.from_numpy(
+                np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
+
+            # this equals
+            infer_state.position_cos = torch.index_select(self._cos_cached, 0,
+                                                          position_ids).view(position_ids.shape[0], -1)
+            infer_state.position_sin = torch.index_select(self._sin_cached, 0,
+                                                          position_ids).view(position_ids.shape[0], -1)

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
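As a side note, here is a self-contained sketch of the gather that the `HAS_VLLM_KERNERL` branch above performs (toy sizes, CPU tensors, and local `_cos_cached` / `_sin_cached` stand-ins rather than the model's buffers): per-sequence lengths are flattened into one position id per prompt token, and those ids select the matching cos/sin rows.

```python
# Sketch only: how flattened position ids index the rotary cos/sin caches.
import numpy as np
import torch

b_seq_len_numpy = np.array([3, 5])    # two sequences in the batch
position_ids = torch.from_numpy(
    np.concatenate([np.arange(0, n) for n in b_seq_len_numpy], axis=0))
# -> tensor([0, 1, 2, 0, 1, 2, 3, 4]): one position per prompt token, per sequence

max_pos, rot_dim = 16, 8              # toy rotary cache
_cos_cached = torch.randn(max_pos, rot_dim)
_sin_cached = torch.randn(max_pos, rot_dim)

# Gather one cos/sin row per token, matching the shapes stored on infer_state.
position_cos = torch.index_select(_cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
position_sin = torch.index_select(_sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
assert position_cos.shape == (int(b_seq_len_numpy.sum()), rot_dim)
```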
@@ -241,64 +245,70 @@ def llama_flash_attn_kvcache_forward(

         bsz, q_len, _ = hidden_states.size()

-        # TODO might think about better way to handle transposed k and v
+        # NOTE might think about better way to handle transposed k and v
         # key_states            [bs, seq_len, num_heads, head_dim/embed_size_per_head]
         # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states_transposed = key_states.transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states_transposed = key_states.transpose(1, 2)

-        # cos, sin = self.rotary_emb(value_states_transposed, seq_len=kv_seq_len)
-        cos, sin = infer_state.position_cos, infer_state.position_sin
+        # NOTE might want to revise
+        #   need some way to record the length of past key values cache
+        #   since we won't return past_key_value_cache right now
+        if infer_state.decode_layer_id == 0:    # once per model.forward
+            infer_state.cache_manager.past_key_values_length += q_len    # seq_len

         if HAS_VLLM_KERNERL:
+            cos, sin = infer_state.position_cos, infer_state.position_sin
             cos_sin_cache = torch.cat((cos, sin), dim=-1)
             rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
             key_states = key_states_transposed.transpose(1, 2)
         else:
             # NOTE: there are some issues for original rotary_embedding_neox of huggingface
+            value_states_transposed = value_states.transpose(1, 2)
+            cos, sin = self.rotary_emb(value_states_transposed,
+                                       seq_len=infer_state.cache_manager.past_key_values_length)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin,
                                                             position_ids)
+            key_states = key_states_transposed.transpose(1, 2)

         def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
-            num_heads = key_buffer.shape[2]
-            head_dim = key_buffer.shape[3]
-            key_buffer = key_buffer.view(-1, num_heads, head_dim)
-            value_buffer = value_buffer.view(-1, num_heads, head_dim)
             copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
             copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
             return

-        # copy key and value calculated in current step to memory manager
-        if infer_state.is_context_stage:
-            _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
-                                  infer_state.cache_manager)
-        else:
-            _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.decode_mem_index,
-                                  infer_state.cache_manager)
-
-        # FIXME might want to revise
-        #   need some way to record the length of past key values cache
-        #   since we won't return past_key_value_cache right now
-        if infer_state.decode_layer_id == 0:    # once per model.forward
-            infer_state.cache_manager.past_key_values_length += q_len    # seq_len
-
-        query_states = query_states.transpose(1, 2)
+        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+        query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)

         if infer_state.is_context_stage:
             # first token generation
-            attn_output = torch.empty_like(query_states)
+            # copy key and value calculated in current step to memory manager
+            _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index,
+                                  infer_state.cache_manager)

-            # calcu_shape for context_attention_fwd
-            calcu_shape1 = (-1, self.num_heads, self.head_dim)
+            attn_output = torch.empty_like(query_states)

-            llama_context_attn_fwd(query_states.view(calcu_shape1), key_states.view(calcu_shape1),
-                                   value_states.view(calcu_shape1), attn_output.view(calcu_shape1),
-                                   infer_state.start_loc, infer_state.seq_len,
-                                   infer_state.cache_manager.past_key_values_length)
+            llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc,
+                                   infer_state.seq_len, infer_state.cache_manager.past_key_values_length)
         else:
+
+            if infer_state.decode_is_contiguous:
+                # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+                    infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+                    infer_state.decode_mem_start:infer_state.decode_mem_end, :, :]
+                cache_k.copy_(key_states)
+                cache_v.copy_(value_states)
+            else:
+                # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head
+                _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states,
+                                      infer_state.decode_mem_index, infer_state.cache_manager)
+
             # second token and follows
             # kv = torch.stack((key_states, value_states), dim=2)
             # (batch_size, seqlen, nheads, headdim)
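To close, a minimal sketch of the decode-stage cache write that the new llama.py branch above distinguishes (hypothetical sizes; plain PyTorch `index_copy_` stands in for the Triton `copy_kv_cache_to_dest` kernel): when the slots allocated for this step are contiguous, a sliced `copy_` suffices; otherwise the new rows are scattered to their slots by index. The value cache is handled identically.

```python
# Sketch only: contiguous vs. non-contiguous KV-cache writes during decoding.
import torch

num_slots, num_heads, head_dim, batch = 32, 4, 8, 3
key_buffer = torch.zeros(num_slots, num_heads, head_dim)    # one layer's key cache
key_states = torch.randn(batch, num_heads, head_dim)        # one new token per sequence

decode_is_contiguous = True
if decode_is_contiguous:
    # Slots form one contiguous range: fill them with a cheap slice copy_.
    decode_mem_start, decode_mem_end = 10, 13                # slots 10..12
    key_buffer[decode_mem_start:decode_mem_end, :, :].copy_(key_states)
else:
    # Slots are scattered: write each new row to its own slot index.
    decode_mem_index = torch.tensor([7, 19, 25])
    key_buffer.index_copy_(0, decode_mem_index, key_states)
```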