From fffb59d4d1a5d225868e44888f206d6600a7ff84 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sat, 2 Sep 2023 15:06:37 +0800
Subject: [PATCH 1/8] fix tests

---
 tests/test_infer/test_bloom_infer.py |  4 ++++
 tests/test_infer/test_llama_infer.py | 36 ++++++++++++++--------------
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
index 95ab7d5c451e..ab1df2e94e35 100644
--- a/tests/test_infer/test_bloom_infer.py
+++ b/tests/test_infer/test_bloom_infer.py
@@ -1,3 +1,4 @@
+import os
 import pytest
 import torch
 import torch.distributed as dist
@@ -18,6 +19,9 @@ def run():
 
     model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
+    if os.path.isdir(model_path) is False:
+        return
+
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.pad_token = tokenizer.eos_token
 
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index c8f852aef420..d23150d15985 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -1,17 +1,17 @@
 import os
-import numpy as np
 import pytest
 import torch
-import torch.distributed as dist
-from transformers import LlamaForCausalLM, LlamaTokenizer
+import numpy as np
 
 import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from transformers import LlamaForCausalLM, LlamaTokenizer
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+import torch.distributed as dist
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 TPSIZE = 2
@@ -19,22 +19,20 @@
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
 
-
 def init_to_get_rotary(self, base=10000):
     self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
     if not hasattr(self.config, "rope_scaling"):
         rope_scaling_factor = 1.0
     else:
         rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config, "max_sequence_length"):
+    if hasattr(self.config,"max_sequence_length"):
         max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config, "max_position_embeddings"):
+    elif hasattr(self.config,"max_position_embeddings"):
         max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
     else:
-        max_seq_len = 2048 * rope_scaling_factor
+        max_seq_len = 2048 * rope_scaling_factor
     base = float(base)
-    inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
-                             self.config.head_dim_))
+    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
     t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
     freqs = torch.outer(t, inv_freq)
@@ -42,33 +40,35 @@ def init_to_get_rotary(self, base=10000):
     self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
     return
 
-
 @parameterize('test_config', [{
     'tp_size': TPSIZE,
 }])
 def run_llama_test(test_config):
-
+    llama_model_path = "/data/scratch/llama-7b-hf"
+    if os.path.isdir(llama_model_path) is False:
+        return
+
     tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     init_to_get_rotary(model.model, base=10000)
     model = model.half()
-
+
     text = "how is weather today?"
     input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')
-
+
     infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)
-
+
     infer_engine.prepare_with_shard_config(shard_config)
     infer_engine.shard_model_by(shardformer)
 
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
     outputs = infer_engine.generate(input_ids, generate_kwargs)
 
     print("outputs.shape: ", outputs.shape)
-
+
     print("outputs: ", outputs)
     output_text = tokenizer.decode(outputs[0])

From 0c3776b86e308638b5117166dc9158d0a8fbdb38 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sat, 2 Sep 2023 15:14:02 +0800
Subject: [PATCH 2/8] clean

---
 tests/test_infer/test_llama_infer.py        | 11 ++++++-----
 .../triton/test_llama_context_attention.py  |  2 +-
 tests/test_infer_ops/triton/utils.py        |  1 +
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index d23150d15985..a87826719ecf 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -3,15 +3,16 @@
 import pytest
 import torch
 import numpy as np
+import torch.distributed as dist
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from transformers import LlamaForCausalLM, LlamaTokenizer
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
-import torch.distributed as dist
+
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 TPSIZE = 2
@@ -25,12 +26,12 @@ def init_to_get_rotary(self, base=10000):
         rope_scaling_factor = 1.0
     else:
         rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config,"max_sequence_length"):
+    if hasattr(self.config, "max_sequence_length"):
         max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config,"max_position_embeddings"):
+    elif hasattr(self.config, "max_position_embeddings"):
         max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
     else:
-        max_seq_len = 2048 * rope_scaling_factor
+        max_seq_len = 2048 * rope_scaling_factor
     base = float(base)
     inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
     t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
     freqs = torch.outer(t, inv_freq)
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
index e7446b289acd..b0fac1263047 100644
--- a/tests/test_infer_ops/triton/test_llama_context_attention.py
+++ b/tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -44,7 +44,7 @@ def test_llama_context_attention():
 
     torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
 
-    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"
+    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched"
 
     latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len)
     latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)
diff --git a/tests/test_infer_ops/triton/utils.py b/tests/test_infer_ops/triton/utils.py
index 940d277cfb02..2787006ba304 100644
--- a/tests/test_infer_ops/triton/utils.py
+++ b/tests/test_infer_ops/triton/utils.py
@@ -36,6 +36,7 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
     xv = xv.view(bs, seqlen, num_head, head_dim)
     mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
     mask[mask == 0.] = -100000000.0
+    # mask[mask==1.0] = 0
     mask = mask.repeat(bs, num_head, 1, 1)
     keys = xk
     values = xv

From 7711dbef9a1ec07cb630f4e7af5972c143be4c38 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sat, 2 Sep 2023 15:15:27 +0800
Subject: [PATCH 3/8] clean

---
 tests/test_infer_ops/triton/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_infer_ops/triton/utils.py b/tests/test_infer_ops/triton/utils.py
index 2787006ba304..940d277cfb02 100644
--- a/tests/test_infer_ops/triton/utils.py
+++ b/tests/test_infer_ops/triton/utils.py
@@ -36,7 +36,6 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
    xv = xv.view(bs, seqlen, num_head, head_dim)
    mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
    mask[mask == 0.] = -100000000.0
-    # mask[mask==1.0] = 0
    mask = mask.repeat(bs, num_head, 1, 1)
    keys = xk
    values = xv

From 186eea32c2f59f89362f822a80c418620ca806a2 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sat, 2 Sep 2023 19:12:27 +0800
Subject: [PATCH 4/8] fix bugs

---
 tests/test_infer/test_bloom_infer.py     | 6 ++++--
 tests/test_infer/test_infer_engine.py    | 8 ++++++--
 tests/test_infer/test_kvcache_manager.py | 5 +++--
 tests/test_infer/test_llama_infer.py     | 6 ++++--
 4 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
index ab1df2e94e35..754c158e6279 100644
--- a/tests/test_infer/test_bloom_infer.py
+++ b/tests/test_infer/test_bloom_infer.py
@@ -1,5 +1,6 @@
 import os
 import pytest
+from packaging import version
 import torch
 import torch.distributed as dist
 from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
@@ -15,9 +16,10 @@
 MAX_INPUT_LEN = 16
 MAX_OUTPUT_LEN = 32
 
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
-def run():
+def run():
     model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
     if os.path.isdir(model_path) is False:
         return
@@ -52,7 +54,7 @@ def check_engine(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run()
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py
index 6fcf9fe0f387..c9432509d941 100644
--- a/tests/test_infer/test_infer_engine.py
+++ b/tests/test_infer/test_infer_engine.py
@@ -1,6 +1,7 @@
+import pytest
 from itertools import accumulate
+from packaging import version
 
-import pytest
 import torch
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
@@ -18,7 +19,9 @@
 MAX_INPUT_LEN = 16
 MAX_OUTPUT_LEN = 8
 
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 def test_prepare_data():
     # dummy module used for testing
     class DummyModule(nn.Module):
@@ -68,7 +71,7 @@ def __init__(self):
     assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
     assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 def test_orig_generate():
     input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
@@ -116,6 +119,7 @@ def check_engine(rank, world_size, port):
     run()
 
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py
index fb04d7800ea2..f57c6956f817 100644
--- a/tests/test_infer/test_kvcache_manager.py
+++ b/tests/test_infer/test_kvcache_manager.py
@@ -1,5 +1,5 @@
 import os
-
+from packaging import version
 import pytest
 import torch
 
@@ -14,6 +14,7 @@
 HEAD_NUM = 32
 HEAD_DIM = 128
 
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
 def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
     os.environ['RANK'] = str(rank)
@@ -42,7 +43,7 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l
     kvcache_manager.alloc_contiguous(batch_size)
     assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_cache_manager_dist():
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index a87826719ecf..92446254b93a 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -1,7 +1,7 @@
 import os
-
 import pytest
 import torch
+from packaging import version
 import numpy as np
 import torch.distributed as dist
 from transformers import LlamaForCausalLM, LlamaTokenizer
@@ -20,6 +20,8 @@
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
 
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
+
 def init_to_get_rotary(self, base=10000):
     self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
     if not hasattr(self.config, "rope_scaling"):
@@ -81,7 +83,7 @@ def check_llama(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_llama_test()
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()

From 4fcce416068756bb6925c2b1538fdd33b82b685b Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sun, 3 Sep 2023 11:03:22 +0800
Subject: [PATCH 5/8] add

---
 tests/test_infer/test_llama_infer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index 92446254b93a..c8eb5c2e33e3 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -90,6 +90,5 @@ def check_llama(rank, world_size, port):
 def test_llama():
     spawn(check_llama, TPSIZE)
 
-
 if __name__ == "__main__":
     test_llama()

From 3aa86055d4eebba637945367729f0ce397b70ac0 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sun, 3 Sep 2023 17:19:34 +0800
Subject: [PATCH 6/8] fix llama non-vllm kernels bug

---
 .../tensor_parallel/modeling/llama.py | 63 ++++++++++++++-----
 .../tensor_parallel/policies/llama.py |  6 --
 tests/test_infer/test_llama_infer.py  |  4 +-
 3 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 82f294163fd7..db95361ed900 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -8,7 +8,6 @@
     LlamaDecoderLayer,
     LlamaModel,
     LlamaRMSNorm,
-    apply_rotary_pos_emb,
 )
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
@@ -29,6 +28,29 @@
     )
     HAS_VLLM_KERNERL = False
 
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+    return
+
 
 class LlamaInferenceForwards:
     """
@@ -251,8 +273,9 @@ def llama_flash_attn_kvcache_forward(
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
         key_states_transposed = key_states.transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+
         # NOTE might want to revise
         #   need some way to record the length of past key values cache
         if infer_state.decode_layer_id == 0:    # once per model.forward
             infer_state.cache_manager.past_key_values_length += q_len    # seq_len
 
+        HAS_VLLM_KERNERL = False
         if HAS_VLLM_KERNERL:
-            cos, sin = infer_state.position_cos, infer_state.position_sin
+            # cos, sin = infer_state.position_cos, infer_state.position_sin
+            value_states_transposed = value_states.transpose(1, 2)
+
+            cos, sin = self.rotary_emb(value_states_transposed,
+                                       seq_len=infer_state.cache_manager.past_key_values_length)
+
             cos_sin_cache = torch.cat((cos, sin), dim=-1)
             rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
             key_states = key_states_transposed.transpose(1, 2)
+            key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+            value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
         else:
             # NOTE: there are some issues for original rotary_embedding_neox of huggingface
+            value_states_transposed = value_states.transpose(1, 2)
             cos, sin = self.rotary_emb(value_states_transposed,
-                                       seq_len=infer_state.cache_manager.past_key_values_length)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)
-            key_states = key_states_transposed.transpose(1, 2)
-
-        def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
-            copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
-            copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
-            return
-
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
-        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
-        query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
+                                       seq_len=infer_state.cache_manager.past_key_values_length)
+
+            rotary_positions_ids = position_ids
+            idx = position_ids.shape[0] - 1
+            if idx >= 1:
+                rotary_positions_ids = [[idx]]
+
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, rotary_positions_ids)
+
+            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
+            key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
+            value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
 
         if infer_state.is_context_stage:
             # first token generation
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
index bbd2156b8523..05e9fd7dc3ee 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -16,12 +16,6 @@ def module_policy(self):
         policy = super().module_policy()
         self.shard_config._infer()
 
-        # example for replace layer or decoder
-        # if self.shard_config.enable_flash_attention:
-        #     policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
-        #         'forward': get_llama_flash_attention_forward(),
-        #     })
-
         infer_forward = LlamaInferenceForwards.llama_model_forward
         method_replacement = {'forward': partial(infer_forward)}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index c8eb5c2e33e3..3caa73ac23d4 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -15,7 +15,7 @@
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-TPSIZE = 2
+TPSIZE = 1
 BATCH_SIZE = 8
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
@@ -58,7 +58,7 @@ def run_llama_test(test_config):
     init_to_get_rotary(model.model, base=10000)
     model = model.half()
 
-    text = "how is weather today?"
+    text = "where is the location of the capital of france?"
     input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')

From a2a1346465f7112747c9cb16f422344254e7bba3 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sun, 3 Sep 2023 17:20:42 +0800
Subject: [PATCH 7/8] modify

---
 colossalai/inference/tensor_parallel/modeling/llama.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index db95361ed900..1e4717eccb18 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -283,7 +283,6 @@ def llama_flash_attn_kvcache_forward(
         if infer_state.decode_layer_id == 0:    # once per model.forward
             infer_state.cache_manager.past_key_values_length += q_len    # seq_len
 
-        HAS_VLLM_KERNERL = False
         if HAS_VLLM_KERNERL:
             # cos, sin = infer_state.position_cos, infer_state.position_sin
             value_states_transposed = value_states.transpose(1, 2)

From 470243340da6ca9b6646c9d3c5128d00f6eb5932 Mon Sep 17 00:00:00 2001
From: "cuiqing.li"
Date: Sun, 3 Sep 2023 17:53:49 +0800
Subject: [PATCH 8/8] clean codes

---
 .../tensor_parallel/modeling/llama.py | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 1e4717eccb18..102422d6f97c 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -284,18 +284,25 @@ def llama_flash_attn_kvcache_forward(
             infer_state.cache_manager.past_key_values_length += q_len    # seq_len
 
         if HAS_VLLM_KERNERL:
-            # cos, sin = infer_state.position_cos, infer_state.position_sin
+            # NOTE: fix rotatry embedding precision problem
+            cos, sin = infer_state.position_cos, infer_state.position_sin
+
+            # value_states_transposed = value_states.transpose(1, 2)
-            value_states_transposed = value_states.transpose(1, 2)
-            cos, sin = self.rotary_emb(value_states_transposed,
-                                       seq_len=infer_state.cache_manager.past_key_values_length)
+            # cos, sin = self.rotary_emb(value_states_transposed,
+            #                            seq_len=infer_state.cache_manager.past_key_values_length)
 
             cos_sin_cache = torch.cat((cos, sin), dim=-1)
-            rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
-            key_states = key_states_transposed.transpose(1, 2)
+
+            key_states = key_states.view(-1, self.num_heads * self.head_dim)
+            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads * self.head_dim)
+            rotary_embedding_neox(position_ids.squeeze(1), query_states, key_states, self.head_dim, cos_sin_cache)
+
+
+            query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
             key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
             value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
-            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
+
         else:
             # NOTE: there are some issues for original rotary_embedding_neox of huggingface