diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
index dad3f9cb295f..4036b5a85918 100644
--- a/tests/test_infer/test_bloom_infer.py
+++ b/tests/test_infer/test_bloom_infer.py
@@ -43,7 +43,7 @@ def run():
     infer_engine.shard_model_by(shardformer)
 
     generate_kwargs = dict(do_sample=False)
-    outputs = infer_engine.generate(input_ids, generate_kwargs)
+    outputs = infer_engine.generate(input_ids, **generate_kwargs)
 
     if not dist.is_initialized() or dist.get_rank() == 0:
         output_text = tokenizer.decode(outputs[0])
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index 1d043ba59338..30c8e2b45185 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -73,7 +73,7 @@ def run_llama_test(test_config):
     infer_engine.shard_model_by(shardformer)
 
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-    outputs = infer_engine.generate(input_ids, generate_kwargs)
+    outputs = infer_engine.generate(input_ids, **generate_kwargs)
 
     #print("outputs.shape: ", outputs.shape)
     #print("outputs: ", outputs)
diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py
index ea89d6bb4764..7447c85c5887 100644
--- a/tests/test_infer_ops/triton/test_bloom_context_attention.py
+++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py
@@ -1,16 +1,17 @@
-import pytest
 import math
-from packaging import version
+import pytest
 import torch
+from packaging import version
 from torch import nn
 from torch.nn import functional as F
 
 try:
     import triton
     import triton.language as tl
-    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import bloom_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import torch_context_attention
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -18,33 +19,36 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+                    reason="triton requires cuda version to be higher than 11.4")
 def test_bloom_context_attention():
     bs = 4
     head_num = 8
     seq_len = 1024
     head_dim = 64
-
-    query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-
+    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+
     max_input_len = seq_len
-    b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
-    b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
-
+    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+
     for i in range(bs):
         b_start[i] = i * seq_len
         b_len[i] = seq_len
-
-    o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+
+    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
     alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
     bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)
-
+
     torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
-
-    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"
+
+    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
+                          atol=1e-2), "outputs from triton and torch are not matched"
+
 
 if __name__ == "__main__":
-    test_bloom_context_attention()
\ No newline at end of file
+    test_bloom_context_attention()
diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py
index 188493eb13ce..c656f81d2790 100644
--- a/tests/test_infer_ops/triton/test_copy_kv_dest.py
+++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py
@@ -1,13 +1,12 @@
 import pytest
-from packaging import version
-
 import torch
+from packaging import version
 from torch import nn
 
 try:
     import triton
     import triton.language as tl
-    from tests.test_kernels.triton.utils import benchmark
+    from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
 
     HAS_TRITON = True
 except ImportError:
@@ -16,23 +15,25 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+                    reason="triton requires cuda version to be higher than 11.4")
 def test_kv_cache_copy_op():
-
+
     B_NTX = 32 * 2048
     head_num = 8
     head_dim = 64
-
+
     cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
     dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)
-
+
     dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
-
+
     copy_kv_cache_to_dest(cache, dest_index, dest_data)
-
-    assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"
-
+
+    assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3,
+                          atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"
+
 
 if __name__ == "__main__":
     test_kv_cache_copy_op()
-
diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py
index 9648f91e2f28..94cd704ffeba 100644
--- a/tests/test_infer_ops/triton/test_layernorm_triton.py
+++ b/tests/test_infer_ops/triton/test_layernorm_triton.py
@@ -4,7 +4,6 @@
 
 from colossalai.kernel.triton import layer_norm
 from colossalai.testing.utils import parameterize
-from tests.test_infer_ops.triton.utils import benchmark
 
 try:
     import triton
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
index 4c49c0b51333..1659fdde8f7f 100644
--- a/tests/test_infer_ops/triton/test_llama_context_attention.py
+++ b/tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -1,16 +1,17 @@
-import pytest
 import math
-from packaging import version
+import pytest
 import torch
+from packaging import version
 from torch import nn
 from torch.nn import functional as F
 
 try:
     import triton
     import triton.language as tl
-    from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention
+    from colossalai.kernel.triton import llama_context_attn_fwd
+    from tests.test_infer_ops.triton.utils import torch_context_attention
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
@@ -19,32 +20,34 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+                    reason="triton requires cuda version to be higher than 11.4")
 def test_llama_context_attention():
     bs = 4
     head_num = 8
     seq_len = 1024
     head_dim = 64
-
-    query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-    v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
-
+    query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+    k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+    v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+
     max_input_len = seq_len
-    b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
-    b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
-
+    b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+    b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)
+
     for i in range(bs):
         b_start[i] = i * seq_len
         b_len[i] = seq_len
-
-    o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
+
+    o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
     llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
-
+
     torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
-
-    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched"
-
+
+    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
+                          atol=1e-3), "outputs from triton and torch are not matched"
+
+
 if __name__ == "__main__":
-    test_llama_context_attention()
\ No newline at end of file
+    test_llama_context_attention()
diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py
index f9457c1a04f7..d5ecdf684538 100644
--- a/tests/test_infer_ops/triton/test_rotary_embedding.py
+++ b/tests/test_infer_ops/triton/test_rotary_embedding.py
@@ -11,7 +11,6 @@
     import triton.language as tl
 
     from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-    from tests.test_infer_ops.triton.utils import benchmark
 
     HAS_TRITON = True
 except ImportError:
@@ -50,7 +49,7 @@ def test_rotary_emb():
     rotary_embedding_fwd(x, cos, sin)
     y_triton = x
     # compare
-    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2)
+    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py
index d01685e7788f..aee7944597dc 100644
--- a/tests/test_infer_ops/triton/test_token_attn_1.py
+++ b/tests/test_infer_ops/triton/test_token_attn_1.py
@@ -59,18 +59,7 @@ def test_attn_1():
         kv_cache_seq_len[i] = seq_len
         b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
 
-    # Warm up
-    for _ in range(10):
-        token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
-
-    run_iter = 1000
-    torch.cuda.synchronize()
-    t1 = time.time()
-    for _ in range(run_iter):
-        token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
-    torch.cuda.synchronize()
-    t2 = time.time()
-    print("Time cost {}".format((t2 - t1) / run_iter))
+    token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
 
     torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
     o = attn_out.squeeze()
@@ -78,5 +67,6 @@
     print("mean ", torch.mean(torch.abs(torch_out - o)))
     assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
 
+
 if __name__ == "__main__":
     test_attn_1()
diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py
index 36b517c4aa3b..f834fedbb0f1 100644
--- a/tests/test_infer_ops/triton/test_token_attn_2.py
+++ b/tests/test_infer_ops/triton/test_token_attn_2.py
@@ -48,17 +48,8 @@ def test_token_attn_2():
         kv_cache_seq_len[i] = seq_len
         kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
 
-    # Warm up
-    for _ in range(10):
-        token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
-    run_iter = 1000
-    torch.cuda.synchronize()
-    t1 = time.time()
-    for _ in range(run_iter):
-        token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
-    torch.cuda.synchronize()
-    t2 = time.time()
-    print("Time cost {}".format((t2 - t1) / run_iter))
+    token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
+
     torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
     o = attn_out
     print("max ", torch.max(torch.abs(torch_out - o)))
diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py
index e765ed4a3415..e82318965e05 100644
--- a/tests/test_infer_ops/triton/test_token_attn_fwd.py
+++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py
@@ -56,18 +56,7 @@ def test():
         kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
 
     token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
-    torch.cuda.synchronize()
-    start = time.time()
-    token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
-    torch.cuda.synchronize()
-    print("cost time:", (time.time() - start) * 1000)
-
-    torch_att(q, k, v, Z, seq_len, head_num, head_dim)
-    torch.cuda.synchronize()
-    start = time.time()
     torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
-    torch.cuda.synchronize()
-    print("cost time:", (time.time() - start) * 1000)
 
     print("max ", torch.max(torch.abs(torch_out - o)))
     print("mean ", torch.mean(torch.abs(torch_out - o)))
diff --git a/tests/test_infer_ops/triton/utils.py b/tests/test_infer_ops/triton/utils.py
index 940d277cfb02..b081b32b9ad3 100644
--- a/tests/test_infer_ops/triton/utils.py
+++ b/tests/test_infer_ops/triton/utils.py
@@ -1,32 +1,10 @@
-import numpy as np
 import math
+import numpy as np
 import torch
 from torch.nn import functional as F
 
-def benchmark(func, *args):
-    starter, ender = torch.cuda.Event(
-        enable_timing=True), torch.cuda.Event(enable_timing=True)
-    repetitions = 300
-
-    for i in range(10):
-        func(*args)
-
-    timings = np.zeros((repetitions, 1))
-    with torch.no_grad():
-        for rep in range(repetitions):
-            starter.record()
-            func(*args)
-            ender.record()
-            # WAIT FOR GPU SYNC
-            torch.cuda.synchronize()
-            curr_time = starter.elapsed_time(ender)
-            timings[rep] = curr_time
-
-    mean_syn = np.sum(timings) / repetitions
-    return mean_syn
-
 
 def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
     '''
     adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
@@ -42,9 +20,9 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
     xq = xq.transpose(1, 2)
     keys = keys.transpose(1, 2)
     values = values.transpose(1, 2)
-    sm_scale = 1/math.sqrt(head_dim)
+    sm_scale = 1 / math.sqrt(head_dim)
     scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
     scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)
-
+
     output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
-    return output
\ No newline at end of file
+    return output