diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 986177555b06..c6abb74f080b 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -181,10 +181,11 @@ def prepare_batch_state(self, inputs) -> BatchInferState: max_len_in_batch = -1 if isinstance(inputs, (BatchEncoding, dict)): for i, attn_mask in enumerate(attention_mask): - if isinstance(attn_mask, torch.Tensor): - curr_seq_len = int(torch.sum(attn_mask)) - else: - curr_seq_len = int(sum(attn_mask)) + curr_seq_len = len(attn_mask) + # if isinstance(attn_mask, torch.Tensor): + # curr_seq_len = int(torch.sum(attn_mask)) + # else: + # curr_seq_len = int(sum(attn_mask)) seq_lengths[i] = curr_seq_len seq_start_indexes[i] = start_index start_index += curr_seq_len @@ -196,7 +197,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda') batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index cc0236a22e03..219cd1ae0d0e 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -3,16 +3,12 @@ import numpy as np import torch from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaModel, - LlamaRMSNorm, -) +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.context_attention import llama_context_attn_fwd from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd try: @@ -28,24 +24,26 @@ ) HAS_VLLM_KERNERL = False + def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) @@ -75,17 +73,6 @@ def llama_model_forward( batch_size = input_ids.shape[0] # input_ids.shape[0] infer_state = self.infer_state - b_seq_len_numpy = infer_state.seq_len.cpu().numpy() - - if HAS_VLLM_KERNERL: - position_ids = torch.from_numpy( - np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda() - - # this equals - infer_state.position_cos = torch.index_select(self._cos_cached, 0, - position_ids).view(position_ids.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, - position_ids).view(position_ids.shape[0], -1) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -138,7 +125,6 @@ def llama_model_forward( # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange(past_key_values_length, @@ -149,6 +135,17 @@ def llama_model_forward( else: position_ids = position_ids.view(-1, seq_length).long() + if infer_state.is_context_stage: + + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -271,11 +268,9 @@ def llama_flash_attn_kvcache_forward( # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states_transposed = key_states.transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - # NOTE might want to revise # need some way to record the length of past key values 
cache @@ -283,38 +278,20 @@ def llama_flash_attn_kvcache_forward( if infer_state.decode_layer_id == 0: # once per model.forward infer_state.cache_manager.past_key_values_length += q_len # seq_len - if HAS_VLLM_KERNERL: - # NOTE: fix rotatry embedding precision problem - cos, sin = infer_state.position_cos, infer_state.position_sin - - # value_states_transposed = value_states.transpose(1, 2) - - # cos, sin = self.rotary_emb(value_states_transposed, - # seq_len=infer_state.cache_manager.past_key_values_length) - - cos_sin_cache = torch.cat((cos, sin), dim=-1) - - key_states = key_states.view(-1, self.num_heads * self.head_dim) - query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads * self.head_dim) - rotary_embedding_neox(position_ids.squeeze(1), query_states, key_states, self.head_dim, cos_sin_cache) - - - query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - key_states = key_states.reshape(-1, self.num_heads, self.head_dim) - value_states = value_states.reshape(-1, self.num_heads, self.head_dim) - - else: - # NOTE: there are some issues for original rotary_embedding_neox of huggingface - - value_states_transposed = value_states.transpose(1, 2) - cos, sin = self.rotary_emb(value_states_transposed, - seq_len=infer_state.cache_manager.past_key_values_length) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids) - - query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim) - key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim) - value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + cos, sin = infer_state.position_cos, infer_state.position_sin + # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) + + rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_heads, self.head_dim) if infer_state.is_context_stage: # first token generation @@ -364,6 +341,7 @@ def llama_flash_attn_kvcache_forward( def get_llama_vllm_rmsnorm_forward(): if HAS_VLLM_KERNERL: + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): x = hidden_states out = torch.empty_like(x) @@ -379,4 +357,3 @@ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): return _vllm_rmsnorm_forward else: return None - diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py new file mode 100644 index 000000000000..d9d1b2bcf026 --- /dev/null +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -0,0 +1,93 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + q, + Cos, + Sin, + q_bs_stride, + q_h_stride, + q_d_stride, + cos_bs_stride, + cos_d_stride, + total_len, + HEAD_NUM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: 
tl.constexpr, + HEAD_DIM: tl.constexpr, +): + current_head_index = tl.program_id(0) + current_seq_index = tl.program_id(1) + + current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ + None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride + off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ + None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride + + off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride + + q0 = tl.load(q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0) + q1 = tl.load(q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0) + + cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + + out0 = q0 * cos - q1 * sin + out1 = q0 * sin + q1 * cos + + tl.store(q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) + tl.store(q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) + + return + + +@torch.no_grad() +def rotary_embedding_fwd(q, cos, sin): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + _rotary_kernel[grid]( + q, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py new file mode 100644 index 000000000000..1aabd340aedd --- /dev/null +++ b/examples/inference/bench_llama.py @@ -0,0 +1,138 @@ +import os +import time + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch.profiler import ProfilerActivity, profile, record_function +from transformers import LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 1 +BATCH_SIZE = 32 +MAX_INPUT_LEN = 1024 +MAX_OUTPUT_LEN = 256 + + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if 
self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def print_perf_stats(latency_set, config, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * BATCH_SIZE / 1e12)) + + +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_llama_test(test_config): + + llama_model_path = "/data/scratch/llama-7b-hf" + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(model.model, base=10000) + model = model.half() + + model_config = model.config + + infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) + + infer_engine.prepare_with_shard_config(shard_config) + infer_engine.shard_model_by(shardformer) + + batch_size = 2 + max_new_tokens = 128 + input_len = 1024 + + generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False) + input_tokens = { + "input_ids": torch.randint(1, 1000, (batch_size, input_len), device='cuda'), + "attention_mask": torch.ones((batch_size, input_len), device='cuda') + } + + iters = 10 + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + times.append((end - start) / (out_len - input_len)) + infer_engine.cache_manager.free_all() + + print("outputs, ", len(outputs)) + outputs = tokenizer.batch_decode(outputs) + + print_perf_stats(times, model_config) + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("model_inference"): + torch.cuda.synchronize() + outputs = infer_engine.generate(input_tokens, generate_kwargs) + torch.cuda.synchronize() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, 
backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 3caa73ac23d4..1d043ba59338 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -1,18 +1,18 @@ import os + +import numpy as np import pytest import torch -from packaging import version -import numpy as np import torch.distributed as dist +from packaging import version from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.cluster import ProcessGroupMesh -from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.inference.tensor_parallel.engine import TPInferEngine - +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 1 @@ -22,6 +22,7 @@ CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') + def init_to_get_rotary(self, base=10000): self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads if not hasattr(self.config, "rope_scaling"): @@ -35,7 +36,8 @@ def init_to_get_rotary(self, base=10000): else: max_seq_len = 2048 * rope_scaling_factor base = float(base) - inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -43,39 +45,42 @@ def init_to_get_rotary(self, base=10000): self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() return + @parameterize('test_config', [{ 'tp_size': TPSIZE, }]) def run_llama_test(test_config): - + llama_model_path = "/data/scratch/llama-7b-hf" if os.path.isdir(llama_model_path) is False: return - + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) init_to_get_rotary(model.model, base=10000) model = model.half() - - text = "where is the location of the capital of france?" 
- input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda') - + + text = ["how is weather today?", "i am "] + input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda') + + #print("input ids ", input_ids) infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) - + infer_engine.prepare_with_shard_config(shard_config) infer_engine.shard_model_by(shardformer) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) outputs = infer_engine.generate(input_ids, generate_kwargs) - print("outputs.shape: ", outputs.shape) - - print("outputs: ", outputs) + #print("outputs.shape: ", outputs.shape) - output_text = tokenizer.decode(outputs[0]) - print(output_text) + #print("outputs: ", outputs) + if not dist.is_initialized() or dist.get_rank() == 0: + for o in outputs: + output_text = tokenizer.decode(o) + #print(output_text) def check_llama(rank, world_size, port): @@ -83,6 +88,7 @@ def check_llama(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_llama_test() + @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @@ -90,5 +96,6 @@ def check_llama(rank, world_size, port): def test_llama(): spawn(check_llama, TPSIZE) + if __name__ == "__main__": test_llama() diff --git a/tests/test_infer_ops/triton/test_layernorm.py b/tests/test_infer_ops/triton/test_layernorm_triton.py similarity index 100% rename from tests/test_infer_ops/triton/test_layernorm.py rename to tests/test_infer_ops/triton/test_layernorm_triton.py diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py new file mode 100644 index 000000000000..4413dba642b8 --- /dev/null +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -0,0 +1,52 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm + +import time + +import pytest +import torch +from packaging import version + +try: + import triton + import triton.language as tl + + from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd + from tests.test_infer_ops.triton.utils import benchmark + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0:dim // 2] + x1 = x[:, :, dim // 2:dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +def test_rotary_emb(): + SEQ_LEN = 1 + HEAD_NUM = 32 + HEAD_DIM = 128 + dtype = torch.half + # create data + x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + cos_shape = (SEQ_LEN, HEAD_DIM // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') + # forward pass + y_torch = torch_rotary_emb(x, cos, sin) + rotary_embedding_fwd(x, cos, sin) + y_triton = x + # print("max delta:", torch.max(torch.abs(y_torch - y_triton))) + # compare 
+ assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2)
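Note (not part of the patch): for readers unfamiliar with the new kernel, below is a minimal, illustrative sketch of how rotary_embedding_fwd from colossalai/kernel/triton/rotary_embedding_kernel.py is invoked on a decode-stage batch, with a pure-PyTorch reference that mirrors torch_rotary_emb in the new test above. It assumes a CUDA device with Triton installed; the random cos/sin tensors stand in for rows gathered from the real _cos_cached/_sin_cached caches built by init_to_get_rotary.

# Illustrative sketch, assuming CUDA + Triton are available.
import torch

from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd


def torch_rotary_emb(x, cos, sin):
    # Pure-PyTorch reference (same as in the new test): rotate the two halves
    # of each head dimension by the per-token angles in cos/sin.
    total_tokens, _, dim = x.shape
    x0, x1 = x[:, :, :dim // 2], x[:, :, dim // 2:]
    cos = cos.view(total_tokens, 1, dim // 2)
    sin = sin.view(total_tokens, 1, dim // 2)
    return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)


batch_size, num_heads, head_dim = 4, 32, 128
# Decode stage: one new token per sequence, so q is (total_tokens, num_heads, head_dim)
# with total_tokens == batch_size, matching query_states.view(-1, num_heads, head_dim).
q = torch.randn(batch_size, num_heads, head_dim, dtype=torch.float16, device="cuda")

# One cos/sin row per token; in llama_model_forward these are gathered with
# torch.index_select(self._cos_cached, 0, seq_len - 1) for the decode branch.
cos = torch.randn(batch_size, head_dim // 2, dtype=torch.float16, device="cuda")
sin = torch.randn(batch_size, head_dim // 2, dtype=torch.float16, device="cuda")

expected = torch_rotary_emb(q, cos, sin)
rotary_embedding_fwd(q, cos, sin)    # rotates q in place, as the test relies on
assert torch.allclose(q, expected, atol=1e-2, rtol=1e-2)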