diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 82f294163fd7..102422d6f97c 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -8,7 +8,6 @@ LlamaDecoderLayer, LlamaModel, LlamaRMSNorm, - apply_rotary_pos_emb, ) from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState @@ -29,6 +28,29 @@ ) HAS_VLLM_KERNERL = False +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + class LlamaInferenceForwards: """ @@ -251,8 +273,9 @@ def llama_flash_attn_kvcache_forward( query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) key_states_transposed = key_states.transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + # NOTE 
might want to revise # need some way to record the length of past key values cache @@ -261,26 +284,42 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length += q_len # seq_len if HAS_VLLM_KERNERL: + # NOTE: fix rotary embedding precision problem cos, sin = infer_state.position_cos, infer_state.position_sin + + # value_states_transposed = value_states.transpose(1, 2) + + # cos, sin = self.rotary_emb(value_states_transposed, + # seq_len=infer_state.cache_manager.past_key_values_length) + cos_sin_cache = torch.cat((cos, sin), dim=-1) - rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) - key_states = key_states_transposed.transpose(1, 2) + + key_states = key_states.view(-1, self.num_heads * self.head_dim) + query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads * self.head_dim) + rotary_embedding_neox(position_ids.squeeze(1), query_states, key_states, self.head_dim, cos_sin_cache) + + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + else: # NOTE: there are some issues for original rotary_embedding_neox of huggingface + value_states_transposed = value_states.transpose(1, 2) cos, sin = self.rotary_emb(value_states_transposed, - seq_len=infer_state.cache_manager.past_key_values_length) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids) - key_states = key_states_transposed.transpose(1, 2) - - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): - copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) - return - - key_states = key_states.reshape(-1, self.num_heads, self.head_dim) - 
value_states = value_states.reshape(-1, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim) + seq_len=infer_state.cache_manager.past_key_values_length) + + rotary_positions_ids = position_ids + idx = position_ids.shape[0] - 1 + if idx >= 1: + rotary_positions_ids = [[idx]] + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, rotary_positions_ids) + + query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_heads, self.head_dim) if infer_state.is_context_stage: # first token generation diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index bbd2156b8523..05e9fd7dc3ee 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -16,12 +16,6 @@ def module_policy(self): policy = super().module_policy() self.shard_config._infer() - # example for replace layer or decoder - # if self.shard_config.enable_flash_attention: - # policy[LlamaAttention] = ModulePolicyDescription(method_replacement={ - # 'forward': get_llama_flash_attention_forward(), - # }) - infer_forward = LlamaInferenceForwards.llama_model_forward method_replacement = {'forward': partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index 95ab7d5c451e..754c158e6279 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -1,4 +1,6 @@ +import os import pytest +from packaging import version import torch import torch.distributed as dist from transformers import AutoModelForCausalLM, 
AutoTokenizer, BloomForCausalLM @@ -14,10 +16,14 @@ MAX_INPUT_LEN = 16 MAX_OUTPUT_LEN = 32 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') -def run(): +def run(): model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b" + if os.path.isdir(model_path) is False: + return + tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer.pad_token = tokenizer.eos_token @@ -48,7 +54,7 @@ def check_engine(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run() - +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index 6fcf9fe0f387..c9432509d941 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -1,6 +1,7 @@ +import pytest from itertools import accumulate +from packaging import version -import pytest import torch import torch.nn as nn from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM @@ -18,7 +19,9 @@ MAX_INPUT_LEN = 16 MAX_OUTPUT_LEN = 8 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") def test_prepare_data(): # dummy module used for testing class DummyModule(nn.Module): @@ -68,7 +71,7 @@ def __init__(self): assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc) assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc) - +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") def test_orig_generate(): input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) @@ 
-116,6 +119,7 @@ def check_engine(rank, world_size, port): run() +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index fb04d7800ea2..f57c6956f817 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -1,5 +1,5 @@ import os - +from packaging import version import pytest import torch @@ -14,6 +14,7 @@ HEAD_NUM = 32 HEAD_DIM = 128 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): os.environ['RANK'] = str(rank) @@ -42,7 +43,7 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l kvcache_manager.alloc_contiguous(batch_size) assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False) - +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() def test_cache_manager_dist(): diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index c8f852aef420..3caa73ac23d4 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -1,24 +1,26 @@ import os - -import numpy as np import pytest import torch +from packaging import version +import numpy as np import torch.distributed as dist from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.testing import 
clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.inference.tensor_parallel.engine import TPInferEngine + os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -TPSIZE = 2 +TPSIZE = 1 BATCH_SIZE = 8 MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') def init_to_get_rotary(self, base=10000): self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads @@ -33,8 +35,7 @@ def init_to_get_rotary(self, base=10000): else: max_seq_len = 2048 * rope_scaling_factor base = float(base) - inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / - self.config.head_dim_)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -42,33 +43,35 @@ def init_to_get_rotary(self, base=10000): self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() return - @parameterize('test_config', [{ 'tp_size': TPSIZE, }]) def run_llama_test(test_config): - + llama_model_path = "/data/scratch/llama-7b-hf" + if os.path.isdir(llama_model_path) is False: + return + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) init_to_get_rotary(model.model, base=10000) model = model.half() - - text = "how is weather today?" + + text = "where is the location of the capital of france?" 
input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda') - + infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) - + infer_engine.prepare_with_shard_config(shard_config) infer_engine.shard_model_by(shardformer) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) outputs = infer_engine.generate(input_ids, generate_kwargs) print("outputs.shape: ", outputs.shape) - + print("outputs: ", outputs) output_text = tokenizer.decode(outputs[0]) @@ -80,13 +83,12 @@ def check_llama(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_llama_test() - +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(): spawn(check_llama, TPSIZE) - if __name__ == "__main__": test_llama() diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index e7446b289acd..b0fac1263047 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -44,7 +44,7 @@ def test_llama_context_attention(): torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" + assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched" latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len) latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)