From c7d6d8d9a49633ea87f0146d18a330ff74d30d68 Mon Sep 17 00:00:00 2001
From: Xu Kai
Date: Mon, 4 Sep 2023 17:06:25 +0800
Subject: [PATCH 01/12] fix rotary embedding

---
 .../tensor_parallel/modeling/llama.py         | 49 +++++-----
 .../kernel/triton/rotary_embedding_kernel.py  | 93 +++++++++++++++++++
 tests/test_infer/test_llama_infer.py          | 12 ++-
 .../triton/test_rotary_embedding.py           | 51 ++++++++++
 4 files changed, 174 insertions(+), 31 deletions(-)
 create mode 100644 colossalai/kernel/triton/rotary_embedding_kernel.py
 create mode 100644 tests/test_infer_ops/triton/test_rotary_embedding.py

diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 82f294163fd7..ba72226f8941 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -14,6 +14,7 @@
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.kernel.triton.context_attention import llama_context_attn_fwd
 from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
 from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
 
 try:
@@ -53,17 +54,6 @@ def llama_model_forward(
     batch_size = input_ids.shape[0]    # input_ids.shape[0]
 
     infer_state = self.infer_state
-    b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
-
-    if HAS_VLLM_KERNERL:
-        position_ids = torch.from_numpy(
-            np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
-
-        # this equals
-        infer_state.position_cos = torch.index_select(self._cos_cached, 0,
-                                                      position_ids).view(position_ids.shape[0], -1)
-        infer_state.position_sin = torch.index_select(self._sin_cached, 0,
-                                                      position_ids).view(position_ids.shape[0], -1)
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -117,6 +107,21 @@
     # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
     # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
     infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+    if infer_state.is_context_stage:
+        b_seq_len_numpy = infer_state.seq_len.cpu().numpy()
+        position_ids = torch.from_numpy(
+            np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda()
+
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0,
+                                                      position_ids).view(position_ids.shape[0], -1)
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0,
+                                                      position_ids).view(position_ids.shape[0], -1)
+        position_ids = None
+    else:
+        seq_len = infer_state.seq_len
+        infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+        infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+
     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         position_ids = torch.arange(past_key_values_length,
@@ -249,10 +254,9 @@ def llama_flash_attn_kvcache_forward(
 
     # key_states      [bs, seq_len, num_heads, head_dim/embed_size_per_head]
     # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
     key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
     value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-    key_states_transposed = key_states.transpose(1, 2)
 
     # NOTE might want to revise
     #   need some way to record the length of past key values cache
@@ -260,27 +264,20 @@ def llama_flash_attn_kvcache_forward(
 
     if infer_state.decode_layer_id == 0:    # once per model.forward
         infer_state.cache_manager.past_key_values_length += q_len    # seq_len
 
-    if HAS_VLLM_KERNERL:
-        cos, sin = infer_state.position_cos, infer_state.position_sin
-        cos_sin_cache = torch.cat((cos, sin), dim=-1)
-        rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
-        key_states = key_states_transposed.transpose(1, 2)
-    else:
-        # NOTE: there are some issues for original rotary_embedding_neox of huggingface
-        value_states_transposed = value_states.transpose(1, 2)
-        cos, sin = self.rotary_emb(value_states_transposed,
-                                   seq_len=infer_state.cache_manager.past_key_values_length)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)
-        key_states = key_states_transposed.transpose(1, 2)
+    cos, sin = infer_state.position_cos, infer_state.position_sin
+    # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
+
+    rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+    rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
 
     def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
         copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
         copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
         return
 
+    query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
     key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
     value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
-    query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
 
     if infer_state.is_context_stage:
         # first token generation
diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py
new file mode 100644
index 000000000000..d9d1b2bcf026
--- /dev/null
+++ b/colossalai/kernel/triton/rotary_embedding_kernel.py
@@ -0,0 +1,93 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _rotary_kernel(
+    q,
+    Cos,
+    Sin,
+    q_bs_stride,
+    q_h_stride,
+    q_d_stride,
+    cos_bs_stride,
+    cos_d_stride,
+    total_len,
+    HEAD_NUM: tl.constexpr,
+    BLOCK_HEAD: tl.constexpr,
+    BLOCK_SEQ: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+):
+    current_head_index = tl.program_id(0)
+    current_seq_index = tl.program_id(1)
+
+    current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
+    current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
+
+    dim_range0 = tl.arange(0, HEAD_DIM // 2)
+    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
+        None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride
+    off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[
+        None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride
+
+    off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
+
+    q0 = tl.load(q + off_q0,
+                 mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+                 other=0.0)
+    q1 = tl.load(q + off_q1,
+                 mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
+                 other=0.0)
+
+    cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+    sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
+
+    out0 = q0 * cos - q1 * sin
+    out1 = q0 * sin + q1 * cos
+
+    tl.store(q + off_q0,
+             out0,
+             mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
+    tl.store(q + off_q1,
+             out1,
+             mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM))
+
+    return
+
+
+@torch.no_grad()
+def rotary_embedding_fwd(q, cos, sin):
+    total_len = q.shape[0]
+    head_num = q.shape[1]
+    head_dim = q.shape[2]
+    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
+    BLOCK_HEAD = 4
+    BLOCK_SEQ = 32
+    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
+    if head_dim >= 128:
+        num_warps = 8
+    else:
+        num_warps = 4
+
+    _rotary_kernel[grid](
+        q,
+        cos,
+        sin,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        cos.stride(0),
+        cos.stride(1),
+        total_len,
+        HEAD_NUM=head_num,
+        BLOCK_HEAD=BLOCK_HEAD,
+        BLOCK_SEQ=BLOCK_SEQ,
+        HEAD_DIM=head_dim,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index c8f852aef420..101cb019b462 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -55,9 +55,10 @@ def run_llama_test(test_config):
     init_to_get_rotary(model.model, base=10000)
     model = model.half()
 
-    text = "how is weather today?"
-    input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')
+    text = ["how is weather today?", "i am "]
+    input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda')
+    print("input ids ", input_ids)
 
     infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)
@@ -70,9 +71,10 @@ def run_llama_test(test_config):
     print("outputs.shape: ", outputs.shape)
 
     print("outputs: ", outputs)
-
-    output_text = tokenizer.decode(outputs[0])
-    print(output_text)
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        for o in outputs:
+            output_text = tokenizer.decode(o)
+            print(output_text)
 
 
 def check_llama(rank, world_size, port):
diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py
new file mode 100644
index 000000000000..b8637b7965ef
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_rotary_embedding.py
@@ -0,0 +1,51 @@
+# Adapted from ModelTC https://github.com/ModelTC/lightllm
+
+import time
+
+import pytest
+import torch
+from packaging import version
+
+try:
+    import triton
+    import triton.language as tl
+
+    from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+def torch_rotary_emb(x, cos, sin):
+    seq_len, h, dim = x.shape
+    x0 = x[:, :, 0:dim // 2]
+    x1 = x[:, :, dim // 2:dim]
+    cos = cos.view((seq_len, 1, dim // 2))
+    sin = sin.view((seq_len, 1, dim // 2))
+    o0 = x0 * cos - x1 * sin
+    o1 = x0 * sin + x1 * cos
+    return torch.cat((o0, o1), dim=-1)
+
+
+def test_rotary_emb():
+    SEQ_LEN = 512
+    HEAD_NUM = 16
+    HEAD_DIM = 64
+    dtype = torch.half
+    # create data
+    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
+    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda')
+    # forward pass
+    y_torch = torch_rotary_emb(x, cos, sin)
+    rotary_embedding_fwd(x, cos, sin)
+    y_triton = x
+
+    # compare
+    print("max delta:", torch.max(torch.abs(y_torch - y_triton)))
+    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2)

From 64c0782bf7e918164d6873407bca4c6c428d8199 Mon Sep 17 00:00:00 2001
From: Xu Kai
Date: Tue, 5 Sep 2023 13:41:07 +0800
Subject: [PATCH 02/12] delete print

---
 tests/test_infer/test_llama_infer.py          |  8 ++++----
 .../triton/test_rotary_embedding.py           | 16 +++++++++++-----
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index b098375768fa..1d043ba59338 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -64,7 +64,7 @@ def run_llama_test(test_config):
     text = ["how is weather today?", "i am "]
     input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda')
-    print("input ids ", input_ids)
+    #print("input ids ", input_ids)
 
     infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)
@@ -74,13 +74,13 @@ def run_llama_test(test_config): generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) outputs = infer_engine.generate(input_ids, generate_kwargs) - print("outputs.shape: ", outputs.shape) + #print("outputs.shape: ", outputs.shape) - print("outputs: ", outputs) + #print("outputs: ", outputs) if not dist.is_initialized() or dist.get_rank() == 0: for o in outputs: output_text = tokenizer.decode(o) - print(output_text) + #print(output_text) def check_llama(rank, world_size, port): diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index b8637b7965ef..3a82a6603318 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -11,6 +11,8 @@ import triton.language as tl from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd + from tests.test_infer_ops.triton.utils import benchmark + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -31,9 +33,9 @@ def torch_rotary_emb(x, cos, sin): def test_rotary_emb(): - SEQ_LEN = 512 - HEAD_NUM = 16 - HEAD_DIM = 64 + SEQ_LEN = 1 + HEAD_NUM = 32 + HEAD_DIM = 128 dtype = torch.half # create data x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) @@ -45,7 +47,11 @@ def test_rotary_emb(): y_torch = torch_rotary_emb(x, cos, sin) rotary_embedding_fwd(x, cos, sin) y_triton = x - + # print("max delta:", torch.max(torch.abs(y_torch - y_triton))) # compare - print("max delta:", torch.max(torch.abs(y_torch - y_triton))) assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) + + # triton_latency = benchmark(rotary_embedding_fwd, x, cos, sin) + # torch_latency = benchmark(torch_rotary_emb, x, cos, sin) + # print("triton kernel latency:{:.6f} ms".format(triton_latency)) + # print("torch kernel latency:{:.6f} ms".format(torch_latency)) From 64d8d557e4c9a99902816421747ec86b534f4d03 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 13:49:11 +0800 Subject: [PATCH 03/12] fix init seq len bug --- colossalai/inference/tensor_parallel/engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 986177555b06..c6abb74f080b 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -181,10 +181,11 @@ def prepare_batch_state(self, inputs) -> BatchInferState: max_len_in_batch = -1 if isinstance(inputs, (BatchEncoding, dict)): for i, attn_mask in enumerate(attention_mask): - if isinstance(attn_mask, torch.Tensor): - curr_seq_len = int(torch.sum(attn_mask)) - else: - curr_seq_len = int(sum(attn_mask)) + curr_seq_len = len(attn_mask) + # if isinstance(attn_mask, torch.Tensor): + # curr_seq_len = int(torch.sum(attn_mask)) + # else: + # curr_seq_len = int(sum(attn_mask)) seq_lengths[i] = curr_seq_len seq_start_indexes[i] = start_index start_index += curr_seq_len @@ -196,7 +197,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda') batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device From 
a255925b7fa32bc3099ddec2cc54389bfe8e1990 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 14:10:06 +0800 Subject: [PATCH 04/12] rename pytest --- .../triton/{test_layernorm.py => test_layernorm_triton.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/test_infer_ops/triton/{test_layernorm.py => test_layernorm_triton.py} (100%) diff --git a/tests/test_infer_ops/triton/test_layernorm.py b/tests/test_infer_ops/triton/test_layernorm_triton.py similarity index 100% rename from tests/test_infer_ops/triton/test_layernorm.py rename to tests/test_infer_ops/triton/test_layernorm_triton.py From c525e78fbf58a29b2d16015155db16b41c095ad8 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 14:10:34 +0800 Subject: [PATCH 05/12] add benchmark for llama --- examples/inference/bench_llama.py | 153 ++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 examples/inference/bench_llama.py diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py new file mode 100644 index 000000000000..7bb1d11259d7 --- /dev/null +++ b/examples/inference/bench_llama.py @@ -0,0 +1,153 @@ +import os +import time + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch.profiler import ProfilerActivity, profile, record_function +from transformers import LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 1 +BATCH_SIZE = 32 +MAX_INPUT_LEN = 1024 +MAX_OUTPUT_LEN = 256 + + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def print_perf_stats(latency_set, config, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * 
num_bytes * BATCH_SIZE / 1e12)) + + +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_llama_test(test_config): + + llama_model_path = "/data/scratch/llama-7b-hf" + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(model.model, base=10000) + model = model.half() + input = "" + for i in range(1022): + input += "a " + model_config = model.config + text = [input] + + batch_size = 32 + input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda') + + print("input ids ", input_ids) + infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) + + infer_engine.prepare_with_shard_config(shard_config) + infer_engine.shard_model_by(shardformer) + print("input.shape: ", input_ids["input_ids"].shape) + + max_new_tokens = 128 + generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False) + outputs = infer_engine.generate(input_ids, generate_kwargs) + infer_engine.cache_manager.free_all() + print("outputs.shape: ", outputs.shape) + + input_len = 1024 + input_tokens = { + "input_ids": torch.randint(1, 1000, (batch_size, input_len), device='cuda'), + "attention_mask": torch.ones((batch_size, input_len), device='cuda') + } + input_tokens["input_ids"][:, 0] = 0 + # print("inputs ", input_tokens) + # print(input_tokens["input_ids"].shape) + input_len = input_tokens["input_ids"].shape[1] + + iters = 10 + + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + times.append((end - start) / (out_len - input_len)) + infer_engine.cache_manager.free_all() + + print("outputs, ", len(outputs)) + outputs = tokenizer.batch_decode(outputs) + + print_perf_stats(times, model_config) + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("model_inference"): + torch.cuda.synchronize() + outputs = infer_engine.generate(input_tokens, generate_kwargs) + torch.cuda.synchronize() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() From 117cdf17b32d1e47a62ce1fe81a501ea681411ae Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 16:40:13 +0800 Subject: [PATCH 06/12] refactor codes --- .../tensor_parallel/modeling/llama.py | 27 ++++++++----------- examples/inference/bench_llama.py | 23 +++------------- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ab18d4a97ab7..219cd1ae0d0e 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ 
b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -125,22 +125,6 @@ def llama_model_forward( # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - - if infer_state.is_context_stage: - b_seq_len_numpy = infer_state.seq_len.cpu().numpy() - position_ids = torch.from_numpy( - np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda() - - infer_state.position_cos = torch.index_select(self._cos_cached, 0, - position_ids).view(position_ids.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, - position_ids).view(position_ids.shape[0], -1) - position_ids = None - else: - seq_len = infer_state.seq_len - infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange(past_key_values_length, @@ -151,6 +135,17 @@ def llama_model_forward( else: position_ids = position_ids.view(-1, seq_length).long() + if infer_state.is_context_stage: + + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 7bb1d11259d7..1aabd340aedd 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -74,42 +74,27 @@ def run_llama_test(test_config): model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) init_to_get_rotary(model.model, base=10000) model = model.half() - input = "" - for i in range(1022): - input += "a " - model_config = model.config - text = [input] - batch_size = 32 - input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda') + model_config = model.config - print("input ids ", input_ids) infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) infer_engine.prepare_with_shard_config(shard_config) infer_engine.shard_model_by(shardformer) - print("input.shape: ", input_ids["input_ids"].shape) + batch_size = 2 max_new_tokens = 128 - generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) - infer_engine.cache_manager.free_all() - print("outputs.shape: ", outputs.shape) - input_len = 1024 + + generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False) input_tokens = { "input_ids": 
torch.randint(1, 1000, (batch_size, input_len), device='cuda'), "attention_mask": torch.ones((batch_size, input_len), device='cuda') } - input_tokens["input_ids"][:, 0] = 0 - # print("inputs ", input_tokens) - # print(input_tokens["input_ids"].shape) - input_len = input_tokens["input_ids"].shape[1] iters = 10 - times = [] for i in range(iters): From 2af8002dfa93696ed9eda8f793327b9e17513369 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 16:57:51 +0800 Subject: [PATCH 07/12] delete useless code --- tests/test_infer_ops/triton/test_rotary_embedding.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index 3a82a6603318..4413dba642b8 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -50,8 +50,3 @@ def test_rotary_emb(): # print("max delta:", torch.max(torch.abs(y_torch - y_triton))) # compare assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) - - # triton_latency = benchmark(rotary_embedding_fwd, x, cos, sin) - # torch_latency = benchmark(torch_rotary_emb, x, cos, sin) - # print("triton kernel latency:{:.6f} ms".format(triton_latency)) - # print("torch kernel latency:{:.6f} ms".format(torch_latency)) From f3e304695c447224a287fb0fadff5aca970b41f7 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 6 Sep 2023 09:10:20 +0800 Subject: [PATCH 08/12] add check triton and cuda --- tests/test_infer_ops/triton/test_rotary_embedding.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index 4413dba642b8..9fafd480a956 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -32,6 +32,8 @@ def torch_rotary_emb(x, cos, sin): return torch.cat((o0, o1), dim=-1) +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_rotary_emb(): SEQ_LEN = 1 HEAD_NUM = 32 @@ -50,3 +52,7 @@ def test_rotary_emb(): # print("max delta:", torch.max(torch.abs(y_torch - y_triton))) # compare assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + test_rotary_emb() From d71765a0ef431be0c0cca461c1fa4b4b347f5c8d Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 6 Sep 2023 17:03:59 +0800 Subject: [PATCH 09/12] delete benchmark and fix infer bugs --- tests/test_infer/test_bloom_infer.py | 2 +- tests/test_infer/test_llama_infer.py | 2 +- .../test_infer_ops/triton/test_rotary_embedding.py | 2 +- tests/test_infer_ops/triton/test_token_attn_1.py | 14 ++------------ 4 files changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index dad3f9cb295f..4036b5a85918 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -43,7 +43,7 @@ def run(): infer_engine.shard_model_by(shardformer) generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) + outputs = infer_engine.generate(input_ids, **generate_kwargs) if not dist.is_initialized() or dist.get_rank() == 0: output_text = tokenizer.decode(outputs[0]) diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 1d043ba59338..30c8e2b45185 100644 --- a/tests/test_infer/test_llama_infer.py +++ 
b/tests/test_infer/test_llama_infer.py @@ -73,7 +73,7 @@ def run_llama_test(test_config): infer_engine.shard_model_by(shardformer) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) + outputs = infer_engine.generate(input_ids, **generate_kwargs) #print("outputs.shape: ", outputs.shape) #print("outputs: ", outputs) diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index f9457c1a04f7..92a85934d565 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -50,7 +50,7 @@ def test_rotary_emb(): rotary_embedding_fwd(x, cos, sin) y_triton = x # compare - assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) + assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0) if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py index d01685e7788f..aee7944597dc 100644 --- a/tests/test_infer_ops/triton/test_token_attn_1.py +++ b/tests/test_infer_ops/triton/test_token_attn_1.py @@ -59,18 +59,7 @@ def test_attn_1(): kv_cache_seq_len[i] = seq_len b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - # Warm up - for _ in range(10): - token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - run_iter = 1000 - torch.cuda.synchronize() - t1 = time.time() - for _ in range(run_iter): - token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - torch.cuda.synchronize() - t2 = time.time() - print("Time cost {}".format((t2 - t1) / run_iter)) + token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() o = attn_out.squeeze() @@ -78,5 +67,6 @@ def test_attn_1(): print("mean ", torch.mean(torch.abs(torch_out - o))) assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) + if __name__ == "__main__": test_attn_1() From 6102ab28e062f33b604440cac7ddcb5ce03d1085 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 6 Sep 2023 17:07:15 +0800 Subject: [PATCH 10/12] delete benchmark for tests --- tests/test_infer_ops/triton/test_token_attn_2.py | 13 ++----------- tests/test_infer_ops/triton/test_token_attn_fwd.py | 10 ---------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py index 36b517c4aa3b..f834fedbb0f1 100644 --- a/tests/test_infer_ops/triton/test_token_attn_2.py +++ b/tests/test_infer_ops/triton/test_token_attn_2.py @@ -48,17 +48,8 @@ def test_token_attn_2(): kv_cache_seq_len[i] = seq_len kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - # Warm up - for _ in range(10): - token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - run_iter = 1000 - torch.cuda.synchronize() - t1 = time.time() - for _ in range(run_iter): - token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - torch.cuda.synchronize() - t2 = time.time() - print("Time cost {}".format((t2 - t1) / run_iter)) + token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) + torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze() o = attn_out print("max ", 
torch.max(torch.abs(torch_out - o))) diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index e765ed4a3415..228d436515f9 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -56,18 +56,8 @@ def test(): kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) - torch.cuda.synchronize() - start = time.time() - token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) - torch.cuda.synchronize() - print("cost time:", (time.time() - start) * 1000) - torch_att(q, k, v, Z, seq_len, head_num, head_dim) - torch.cuda.synchronize() - start = time.time() torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) - torch.cuda.synchronize() - print("cost time:", (time.time() - start) * 1000) print("max ", torch.max(torch.abs(torch_out - o))) print("mean ", torch.mean(torch.abs(torch_out - o))) From 38dd40136fa131b0f44def003ba69f0c5a9cf523 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 6 Sep 2023 17:10:47 +0800 Subject: [PATCH 11/12] delete useless code --- tests/test_infer_ops/triton/test_token_attn_fwd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index 228d436515f9..e82318965e05 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -56,7 +56,6 @@ def test(): kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi) - torch_att(q, k, v, Z, seq_len, head_num, head_dim) torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) print("max ", torch.max(torch.abs(torch_out - o))) From d26f5e911361381d7284e263fb31ad57503bb50c Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 6 Sep 2023 17:25:06 +0800 Subject: [PATCH 12/12] delete bechmark function in utils --- .../triton/test_bloom_context_attention.py | 40 ++++++++++-------- .../triton/test_copy_kv_dest.py | 25 +++++------ .../triton/test_layernorm_triton.py | 1 - .../triton/test_llama_context_attention.py | 41 ++++++++++--------- .../triton/test_rotary_embedding.py | 1 - tests/test_infer_ops/triton/utils.py | 30 ++------------ 6 files changed, 61 insertions(+), 77 deletions(-) diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py index ea89d6bb4764..7447c85c5887 100644 --- a/tests/test_infer_ops/triton/test_bloom_context_attention.py +++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py @@ -1,16 +1,17 @@ -import pytest import math -from packaging import version +import pytest import torch +from packaging import version from torch import nn from torch.nn import functional as F try: import triton import triton.language as tl - from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention + from colossalai.kernel.triton import bloom_context_attn_fwd + from tests.test_infer_ops.triton.utils import torch_context_attention HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -18,33 +19,36 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > 
version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_bloom_context_attention(): bs = 4 head_num = 8 seq_len = 1024 head_dim = 64 - - query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - + query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + max_input_len = seq_len - b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32) - b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32) - + b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32) + for i in range(bs): b_start[i] = i * seq_len b_len[i] = seq_len - - o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi) - + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - - assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" + + assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, + atol=1e-2), "outputs from triton and torch are not matched" + if __name__ == "__main__": - test_bloom_context_attention() \ No newline at end of file + test_bloom_context_attention() diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py index 188493eb13ce..c656f81d2790 100644 --- a/tests/test_infer_ops/triton/test_copy_kv_dest.py +++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py @@ -1,13 +1,12 @@ import pytest -from packaging import version - import torch +from packaging import version from torch import nn try: import triton import triton.language as tl - from tests.test_kernels.triton.utils import benchmark + from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest HAS_TRITON = True except ImportError: @@ -16,23 +15,25 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_kv_cache_copy_op(): - + B_NTX = 32 * 2048 head_num = 8 head_dim = 64 - + cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32) - + dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) - + copy_kv_cache_to_dest(cache, dest_index, dest_data) - - assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest 
outputs from triton and torch are not matched" - + + assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, + atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched" + if __name__ == "__main__": test_kv_cache_copy_op() - diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py index 9648f91e2f28..94cd704ffeba 100644 --- a/tests/test_infer_ops/triton/test_layernorm_triton.py +++ b/tests/test_infer_ops/triton/test_layernorm_triton.py @@ -4,7 +4,6 @@ from colossalai.kernel.triton import layer_norm from colossalai.testing.utils import parameterize -from tests.test_infer_ops.triton.utils import benchmark try: import triton diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index 4c49c0b51333..1659fdde8f7f 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -1,16 +1,17 @@ -import pytest import math -from packaging import version +import pytest import torch +from packaging import version from torch import nn from torch.nn import functional as F try: import triton import triton.language as tl - from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention + from colossalai.kernel.triton import llama_context_attn_fwd + from tests.test_infer_ops.triton.utils import torch_context_attention HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -19,32 +20,34 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, + reason="triton requires cuda version to be higher than 11.4") def test_llama_context_attention(): bs = 4 head_num = 8 seq_len = 1024 head_dim = 64 - - query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - + query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + max_input_len = seq_len - b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32) - b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32) - + b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32) + for i in range(bs): b_start[i] = i * seq_len b_len[i] = seq_len - - o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) - + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - - assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched" - + + assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, + atol=1e-3), "outputs from triton and torch are not matched" + + if __name__ == "__main__": - 
test_llama_context_attention() \ No newline at end of file + test_llama_context_attention() diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index 92a85934d565..d5ecdf684538 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -11,7 +11,6 @@ import triton.language as tl from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd - from tests.test_infer_ops.triton.utils import benchmark HAS_TRITON = True except ImportError: diff --git a/tests/test_infer_ops/triton/utils.py b/tests/test_infer_ops/triton/utils.py index 940d277cfb02..b081b32b9ad3 100644 --- a/tests/test_infer_ops/triton/utils.py +++ b/tests/test_infer_ops/triton/utils.py @@ -1,32 +1,10 @@ -import numpy as np import math +import numpy as np import torch from torch.nn import functional as F -def benchmark(func, *args): - starter, ender = torch.cuda.Event( - enable_timing=True), torch.cuda.Event(enable_timing=True) - repetitions = 300 - - for i in range(10): - func(*args) - - timings = np.zeros((repetitions, 1)) - with torch.no_grad(): - for rep in range(repetitions): - starter.record() - func(*args) - ender.record() - # WAIT FOR GPU SYNC - torch.cuda.synchronize() - curr_time = starter.elapsed_time(ender) - timings[rep] = curr_time - - mean_syn = np.sum(timings) / repetitions - return mean_syn - def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): ''' adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 @@ -42,9 +20,9 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) values = values.transpose(1, 2) - sm_scale = 1/math.sqrt(head_dim) + sm_scale = 1 / math.sqrt(head_dim) scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) - + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output \ No newline at end of file + return output