Labels
needs-triage, type: bug
Description
Expected behavior
SplitHostDevice should not produce kernels that reference free (undefined) variables.
Actual behavior
When the script under "Steps to reproduce" is built, SplitHostDevice produces a kernel containing the free variable cse_var_3:
@T.prim_func(private=True)
def default_function_kernel_2(output_buf: T.handle("int32", "global"), cse_var_4: T.int64, i: T.int32, seq_len: T.int32):
    T.func_attr({"target": T.target({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}), "tir.is_global_func": T.bool(True), "tir.noalias": T.bool(True)})
    output_buf_1 = T.decl_buffer((T.Cast("int64", seq_len) * T.int64(8),), "int32", data=output_buf, align=8)
    end_1 = T.handle("int64", "local")
    end_1_1 = T.decl_buffer((1,), "int64", data=end_1, scope="local")
    middle_1 = T.handle("int64", "local")
    middle_1_1 = T.decl_buffer((1,), "int64", data=middle_1, scope="local")
    start_1 = T.handle("int64", "local")
    start_1_1 = T.decl_buffer((1,), "int64", data=start_1, scope="local")
    threadIdx_x = T.launch_thread("threadIdx.x", 256)
    start_1 = T.allocate([1], "int64", "local")
    middle_1 = T.allocate([1], "int64", "local")
    end_1 = T.allocate([1], "int64", "local")
    cse_var_3 = T.int32()
    blockIdx_x = T.launch_thread("blockIdx.x", (cse_var_3 - 1) // (T.shift_left(2, i) * 256) + 1)
    blockIdx_y = T.launch_thread("blockIdx.y", 1)
    start_1_1[0] = T.Cast("int64", T.shift_left(2, i)) * (T.Cast("int64", blockIdx_x) * T.int64(256) + T.Cast("int64", threadIdx_x))
    if start_1_1[0] < cse_var_4:
        middle_1_1[0] = T.Cast("int64", T.shift_left(2, i)) // T.int64(2) + start_1_1[0]
        end_1_1[0] = T.min(start_1_1[0] + T.Cast("int64", T.shift_left(2, i)), cse_var_4)
        if middle_1_1[0] < cse_var_4:
            output_buf_1[end_1_1[0] - T.int64(1)] = output_buf_1[end_1_1[0] - T.int64(1)] + output_buf_1[middle_1_1[0] - T.int64(1)]
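For completeness, the free variable can also be detected mechanically rather than by inspecting the printed kernel. The sketch below is hedged: it assumes the Unity branch's tvm.tir.analysis.verify_well_formed helper and that it flags use of an undefined variable; default_function_kernel_2 stands for the kernel above, re-parsed from the TVMScript dump.

import tvm

# Hedged sketch: assumes verify_well_formed reports undefined-variable use.
# With assert_mode=False it returns False instead of raising, so the free
# cse_var_3 should make the check fail.
assert not tvm.tir.analysis.verify_well_formed(
    default_function_kernel_2, assert_mode=False
)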
Environment
TVM Unity branch
Steps to reproduce
from tvm.script import tir as T
import tvm


@T.prim_func(private=True)
def cumsum(var_A: T.handle, var_T_add: T.handle, seq_len: T.int32):
    T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
    A = T.match_buffer(var_A, (seq_len * 8,), "int32")
    T_add = T.match_buffer(var_T_add, (seq_len * 8,), "int32")
    # with T.block("root"):
    T_expand_dims = T.alloc_buffer((1, seq_len * 8), "int32")
    output_buf = T.alloc_buffer((1, seq_len * 8), "int32", align=8)
    T_squeeze = T.alloc_buffer((seq_len * 8,), "int32")
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_expand_dims"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(A[v0])
                T.writes(T_expand_dims[0, v0])
                T_expand_dims[0, v0] = A[v0]
    with T.block("exclusive_scan"):
        T.reads(T_expand_dims[0, 0:seq_len * 8])
        T.writes(output_buf[0, 0:seq_len * 8])
        if seq_len * 8 == 0:
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
            if blockIdx_x < 1:
                T.evaluate(0)
        else:
            with T.launch_thread("threadIdx.x", 256) as threadIdx_x:
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, (seq_len * 8 + 255) // 256))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                if blockIdx_x * 256 + threadIdx_x < seq_len * 8:
                    output_buf[(blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) // (seq_len * 8), (blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) % (seq_len * 8)] = T_expand_dims[(blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) // (seq_len * 8), (blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) % (seq_len * 8)]
            for i in range(T.Cast("int32", T.ceil(T.log2(T.Cast("float32", seq_len * 8))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 256)
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (seq_len * 8 + (256 * T.shift_left(2, i) - 1)) // (256 * T.shift_left(2, i)))))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                start = T.allocate([1], "int64", "local")
                middle = T.allocate([1], "int64", "local")
                end = T.allocate([1], "int64", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[0] = T.Cast("int64", T.shift_left(2, i)) * T.Cast("int64", blockIdx_x * 256 + threadIdx_x)
                if start_1[0] < T.Cast("int64", seq_len * 8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[0] = start_1[0] + T.Cast("int64", T.shift_left(2, i) // 2)
                    end_1 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_1[0] = T.min(start_1[0] + T.Cast("int64", T.shift_left(2, i)), T.Cast("int64", seq_len * 8))
                    if middle_1[0] < T.Cast("int64", seq_len * 8):
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] + output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
            with T.launch_thread("blockIdx.x", 1) as blockIdx_x:
                if blockIdx_x < 1:
                    output_buf[((blockIdx_x + 1) * (seq_len * 8) - 1) // (seq_len * 8), ((blockIdx_x + 1) * (seq_len * 8) - 1) % (seq_len * 8)] = 0
            for j in range(T.Cast("int32", T.ceil(T.log2(T.Cast("float32", seq_len * 8))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 256)
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (T.Cast("int64", seq_len * 8) + (T.int64(256) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) - T.int64(1))) // (T.int64(256) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1))))))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                start = T.allocate([1], "int64", "local")
                middle = T.allocate([1], "int64", "local")
                end = T.allocate([1], "int64", "local")
                end_1 = T.allocate([1], "int32", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[0] = T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) * T.Cast("int64", blockIdx_x * 256 + threadIdx_x)
                if start_1[0] < T.Cast("int64", seq_len * 8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[0] = start_1[0] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) // T.int64(2)
                    end_2 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_2[0] = T.min(start_1[0] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)), T.Cast("int64", seq_len * 8))
                    if middle_1[0] < T.Cast("int64", seq_len * 8):
                        end_3 = T.Buffer((1,), "int32", data=end_1, scope="local")
                        end_3[0] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] + end_3[0]
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_squeeze"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(output_buf[0, v0])
                T.writes(T_squeeze[v0])
                T_squeeze[v0] = output_buf[0, v0]
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_add"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(A[v0], T_squeeze[v0])
                T.writes(T_add[v0])
                T_add[v0] = A[v0] + T_squeeze[v0]


tvm.build(cumsum, target="metal")
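To pinpoint where cse_var_3 first goes free without reading the final Metal source, one option is a pass instrument that dumps the module right after SplitHostDevice runs inside tvm.build. This is a minimal sketch, assuming the stock tvm.instrument.pass_instrument decorator and the pass name "tir.SplitHostDevice":

@tvm.instrument.pass_instrument
class DumpAfterSplitHostDevice:
    # Print the IRModule immediately after SplitHostDevice runs; the dump of
    # default_function_kernel_2 should show the free cse_var_3.
    def run_after_pass(self, mod, info):
        if info.name == "tir.SplitHostDevice":
            mod.show()

with tvm.transform.PassContext(opt_level=3, instruments=[DumpAfterSplitHostDevice()]):
    tvm.build(cumsum, target="metal")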
cc: @Lunderberg