diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 102422d6f97c..1d277aef7c20 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -309,13 +309,8 @@ def llama_flash_attn_kvcache_forward(
 
     value_states_transposed = value_states.transpose(1, 2)
 
     cos, sin = self.rotary_emb(value_states_transposed, seq_len=infer_state.cache_manager.past_key_values_length)
-
-    rotary_positions_ids = position_ids
-    idx = position_ids.shape[0] - 1
-    if idx >= 1:
-        rotary_positions_ids = [[idx]]
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, rotary_positions_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)
 
     query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
     key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
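
The removed block overrode the rotary positions with a single index derived from position_ids.shape[0] (the batch dimension), so whenever the batch held more than one sequence, every sequence was rotated at the same hard-coded position; it also passed a plain Python list where per-token indices are expected. The patch passes position_ids through unchanged. Below is a minimal, self-contained sketch of why that matters; it mirrors the HF-style apply_rotary_pos_emb signature used in the diff, not ColossalAI's actual kernel, and the tensor shapes and cos/sin table layout are illustrative assumptions.

import torch

def rotate_half(x):
    # Standard RoPE helper: maps (x1, x2) -> (-x2, x1) along the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin: [max_positions, head_dim]; gather each token's own phase
    # via position_ids of shape [bsz, seq_len].
    cos = cos[position_ids].unsqueeze(1)  # -> [bsz, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, n_heads, seq_len, head_dim = 2, 4, 1, 8
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

# Precompute cos/sin tables for 16 absolute positions.
inv_freq = 1.0 / 10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)
freqs = torch.outer(torch.arange(16).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# Two sequences decoding one token each, at different absolute positions.
position_ids = torch.tensor([[3], [7]])
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

# The removed hack replaced position_ids with [[position_ids.shape[0] - 1]],
# i.e. [[1]] here (batch size minus one), rotating both sequences at
# position 1 -- wrong for both of them, and increasingly wrong the deeper
# a sequence is into decoding.

With the fix, the sequence decoding at position 3 and the one at position 7 each receive their own rotary phase, which is what incremental decoding with a shared KV cache requires.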