Merged
2 changes: 1 addition & 1 deletion tests/test_infer/test_bloom_infer.py
@@ -43,7 +43,7 @@ def run():
infer_engine.shard_model_by(shardformer)

generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(input_ids, generate_kwargs)
outputs = infer_engine.generate(input_ids, **generate_kwargs)

if not dist.is_initialized() or dist.get_rank() == 0:
output_text = tokenizer.decode(outputs[0])
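A note on the generate_kwargs change above (the same fix is applied in the llama test below): passing the dict positionally hands the whole dict to the first optional parameter of generate, while **generate_kwargs unpacks it into keyword arguments. A minimal sketch of the difference, using a stand-in generate whose signature is purely illustrative and not the engine's actual API:

    def generate(input_ids, max_new_tokens=16, do_sample=True):
        # Illustrative stand-in; only the calling convention matters here.
        return {"max_new_tokens": max_new_tokens, "do_sample": do_sample}

    generate_kwargs = dict(max_new_tokens=8, do_sample=False)

    # Positional dict: the whole dict binds to max_new_tokens (a silent bug).
    wrong = generate([1, 2, 3], generate_kwargs)

    # Keyword unpacking: each key becomes a keyword argument, as the tests now do.
    right = generate([1, 2, 3], **generate_kwargs)
    assert right == {"max_new_tokens": 8, "do_sample": False}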
2 changes: 1 addition & 1 deletion tests/test_infer/test_llama_infer.py
@@ -73,7 +73,7 @@ def run_llama_test(test_config):
infer_engine.shard_model_by(shardformer)

generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
outputs = infer_engine.generate(input_ids, generate_kwargs)
outputs = infer_engine.generate(input_ids, **generate_kwargs)
#print("outputs.shape: ", outputs.shape)

#print("outputs: ", outputs)
40 changes: 22 additions & 18 deletions tests/test_infer_ops/triton/test_bloom_context_attention.py
@@ -1,50 +1,54 @@
import pytest
import math
from packaging import version

import pytest
import torch
from packaging import version
from torch import nn
from torch.nn import functional as F

try:
import triton
import triton.language as tl
from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention

from colossalai.kernel.triton import bloom_context_attn_fwd
from tests.test_infer_ops.triton.utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_bloom_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64

query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")


query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

max_input_len = seq_len
b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda")
bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi)

torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"

assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
atol=1e-2), "outputs from triton and torch are not matched"


if __name__ == "__main__":
test_bloom_context_attention()
test_bloom_context_attention()
25 changes: 13 additions & 12 deletions tests/test_infer_ops/triton/test_copy_kv_dest.py
@@ -1,13 +1,12 @@
import pytest
from packaging import version

import torch
from packaging import version
from torch import nn

try:
import triton
import triton.language as tl
from tests.test_kernels.triton.utils import benchmark

from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
HAS_TRITON = True
except ImportError:
@@ -16,23 +15,25 @@

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")

@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_kv_cache_copy_op():

B_NTX = 32 * 2048
head_num = 8
head_dim = 64

cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)
dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32)

dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16)

copy_kv_cache_to_dest(cache, dest_index, dest_data)

assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"


assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3,
atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"


if __name__ == "__main__":
test_kv_cache_copy_op()

1 change: 0 additions & 1 deletion tests/test_infer_ops/triton/test_layernorm_triton.py
@@ -4,7 +4,6 @@

from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize
from tests.test_infer_ops.triton.utils import benchmark

try:
import triton
41 changes: 22 additions & 19 deletions tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -1,16 +1,17 @@
import pytest
import math
from packaging import version

import pytest
import torch
from packaging import version
from torch import nn
from torch.nn import functional as F

try:
import triton
import triton.language as tl
from tests.test_infer_ops.triton.utils import benchmark, torch_context_attention

from colossalai.kernel.triton import llama_context_attn_fwd
from tests.test_infer_ops.triton.utils import torch_context_attention
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -19,32 +20,34 @@
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4")
@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
reason="triton requires cuda version to be higher than 11.4")
def test_llama_context_attention():
bs = 4
head_num = 8
seq_len = 1024
head_dim = 64

query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")


query = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
k = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
v = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

max_input_len = seq_len
b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32)
b_start = torch.zeros((bs,), device="cuda", dtype=torch.int32)
b_len = torch.zeros((bs,), device="cuda", dtype=torch.int32)

for i in range(bs):
b_start[i] = i * seq_len
b_len[i] = seq_len
o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")

o = torch.randn((bs * seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)

torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched"


assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3,
atol=1e-3), "outputs from triton and torch are not matched"


if __name__ == "__main__":
test_llama_context_attention()
test_llama_context_attention()
3 changes: 1 addition & 2 deletions tests/test_infer_ops/triton/test_rotary_embedding.py
@@ -11,7 +11,6 @@
import triton.language as tl

from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
from tests.test_infer_ops.triton.utils import benchmark

HAS_TRITON = True
except ImportError:
@@ -50,7 +49,7 @@ def test_rotary_emb():
rotary_embedding_fwd(x, cos, sin)
y_triton = x
# compare
assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2)
assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)


if __name__ == "__main__":
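For reference, torch.allclose(a, b, atol, rtol) checks |a - b| <= atol + rtol * |b| elementwise, so tightening rtol from 1e-2 to 0 in the assertion above turns the comparison into a flat absolute-tolerance check that no longer loosens with magnitude. A small self-contained illustration with arbitrary values (not taken from this test):

    import torch

    a = torch.tensor([100.0])
    b = torch.tensor([100.5])
    # 0.5 <= 1e-2 + 1e-2 * 100.5, so the pair passes with a relative tolerance.
    assert torch.allclose(a, b, atol=1e-2, rtol=1e-2)
    # 0.5 > 1e-2, so the same pair fails once rtol is set to 0.
    assert not torch.allclose(a, b, atol=1e-2, rtol=0)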
14 changes: 2 additions & 12 deletions tests/test_infer_ops/triton/test_token_attn_1.py
@@ -59,24 +59,14 @@ def test_attn_1():
kv_cache_seq_len[i] = seq_len
b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

# Warm up
for _ in range(10):
token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

run_iter = 1000
torch.cuda.synchronize()
t1 = time.time()
for _ in range(run_iter):
token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
torch.cuda.synchronize()
t2 = time.time()
print("Time cost {}".format((t2 - t1) / run_iter))
token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
o = attn_out.squeeze()
print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
test_attn_1()
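The warm-up and wall-clock timing removed from this test (and from the kernel tests below, along with the benchmark helper deleted from utils.py) can still be reproduced outside the unit tests. A minimal standalone sketch using CUDA events; the helper name and defaults are assumptions, not part of this PR:

    import torch

    def time_kernel(fn, *args, warmup=10, iters=100):
        # Assumed standalone helper: CUDA kernels launch asynchronously, so
        # events plus synchronize measure device time, not launch overhead.
        for _ in range(warmup):
            fn(*args)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # milliseconds per call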
13 changes: 2 additions & 11 deletions tests/test_infer_ops/triton/test_token_attn_2.py
@@ -48,17 +48,8 @@ def test_token_attn_2():
kv_cache_seq_len[i] = seq_len
kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

# Warm up
for _ in range(10):
token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
run_iter = 1000
torch.cuda.synchronize()
t1 = time.time()
for _ in range(run_iter):
token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
torch.cuda.synchronize()
t2 = time.time()
print("Time cost {}".format((t2 - t1) / run_iter))
token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
o = attn_out
print("max ", torch.max(torch.abs(torch_out - o)))
11 changes: 0 additions & 11 deletions tests/test_infer_ops/triton/test_token_attn_fwd.py
@@ -56,18 +56,7 @@ def test():
kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")

token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
torch.cuda.synchronize()
start = time.time()
token_attention_fwd(q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, alibi=alibi)
torch.cuda.synchronize()
print("cost time:", (time.time() - start) * 1000)

torch_att(q, k, v, Z, seq_len, head_num, head_dim)
torch.cuda.synchronize()
start = time.time()
torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
torch.cuda.synchronize()
print("cost time:", (time.time() - start) * 1000)

print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
30 changes: 4 additions & 26 deletions tests/test_infer_ops/triton/utils.py
@@ -1,32 +1,10 @@
import numpy as np
import math

import numpy as np
import torch
from torch.nn import functional as F


def benchmark(func, *args):
starter, ender = torch.cuda.Event(
enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 300

for i in range(10):
func(*args)

timings = np.zeros((repetitions, 1))
with torch.no_grad():
for rep in range(repetitions):
starter.record()
func(*args)
ender.record()
# WAIT FOR GPU SYNC
torch.cuda.synchronize()
curr_time = starter.elapsed_time(ender)
timings[rep] = curr_time

mean_syn = np.sum(timings) / repetitions
return mean_syn

def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
'''
adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
@@ -42,9 +20,9 @@ def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
sm_scale = 1/math.sqrt(head_dim)
sm_scale = 1 / math.sqrt(head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale
scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16)

output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
return output
return output