Skip to content

[Bug] SplitHostDevice generates free var when var only exists in T.thread_binding of device function #16237

@jinhongyii

Description

@jinhongyii

Expected behavior

SplitHostDevice should not generate any free variables in the split-out device function.

Actual behavior

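When building the script below (see "Steps to reproduce"), SplitHostDevice produces the following device function. In it, `cse_var_3` is a free variable: it is used in the `blockIdx.x` launch extent, but it is neither a parameter of the function nor defined in its body.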


    # Device kernel emitted by SplitHostDevice for the up-sweep loop of the reproducer.
    # NOTE(review): `cse_var_3` below is declared with `T.int32()` and then only used in
    # the `blockIdx.x` launch extent. It is not a parameter and is never assigned, so it
    # is a free variable -- this is the bug being reported. By contrast, `cse_var_4`
    # (used in the body) is correctly passed in as a parameter.
    @T.prim_func(private=True)
    def default_function_kernel_2(output_buf: T.handle("int32", "global"), cse_var_4: T.int64, i: T.int32, seq_len: T.int32):
        T.func_attr({"target": T.target({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}), "tir.is_global_func": T.bool(True), "tir.noalias": T.bool(True)})
        output_buf_1 = T.decl_buffer((T.Cast("int64", seq_len) * T.int64(8),), "int32", data=output_buf, align=8)
        end_1 = T.handle("int64", "local")
        end_1_1 = T.decl_buffer((1,), "int64", data=end_1, scope="local")
        middle_1 = T.handle("int64", "local")
        middle_1_1 = T.decl_buffer((1,), "int64", data=middle_1, scope="local")
        start_1 = T.handle("int64", "local")
        start_1_1 = T.decl_buffer((1,), "int64", data=start_1, scope="local")
        threadIdx_x = T.launch_thread("threadIdx.x", 256)
        start_1 = T.allocate([1], "int64", "local")
        middle_1 = T.allocate([1], "int64", "local")
        end_1 = T.allocate([1], "int64", "local")
        # Free variable: not a function parameter and never given a value in this body.
        cse_var_3 = T.int32()
        # The only use of cse_var_3 is in this launch extent.
        blockIdx_x = T.launch_thread("blockIdx.x", (cse_var_3 - 1) // (T.shift_left(2, i) * 256) + 1)
        blockIdx_y = T.launch_thread("blockIdx.y", 1)
        start_1_1[0] = T.Cast("int64", T.shift_left(2, i)) * (T.Cast("int64", blockIdx_x) * T.int64(256) + T.Cast("int64", threadIdx_x))
        if start_1_1[0] < cse_var_4:
            middle_1_1[0] = T.Cast("int64", T.shift_left(2, i)) // T.int64(2) + start_1_1[0]
            end_1_1[0] = T.min(start_1_1[0] + T.Cast("int64", T.shift_left(2, i)), cse_var_4)
            if middle_1_1[0] < cse_var_4:
                output_buf_1[end_1_1[0] - T.int64(1)] = output_buf_1[end_1_1[0] - T.int64(1)] + output_buf_1[middle_1_1[0] - T.int64(1)]

Environment

TVM Unity branch

Steps to reproduce

from tvm.script import tir as T
import tvm

# Reproducer: a scheduled cumulative sum over `seq_len * 8` int32 elements.
# Stages:
#   1. T_expand_dims: copy A (shape (seq_len*8,)) into a (1, seq_len*8) buffer.
#   2. exclusive_scan: in-place scan of output_buf using explicit launch_thread
#      kernels -- an initial copy, an up-sweep loop over `i`, zeroing of the
#      last element, then a down-sweep loop over `j`.
#   3. T_squeeze: copy output_buf back into a 1-D buffer.
#   4. T_add: T_add = A + T_squeeze (presumably turning the exclusive prefix
#      sum into an inclusive one).
# Building this for Metal makes SplitHostDevice emit a device kernel that
# references the free variable cse_var_3 (see the issue description).
@T.prim_func(private=True)
def cumsum(var_A: T.handle, var_T_add: T.handle, seq_len: T.int32):
    T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
    A = T.match_buffer(var_A, (seq_len * 8,), "int32")
    T_add = T.match_buffer(var_T_add, (seq_len * 8,), "int32")
    # with T.block("root"):
    T_expand_dims = T.alloc_buffer((1, seq_len * 8), "int32")
    output_buf = T.alloc_buffer((1, seq_len * 8), "int32", align=8)
    T_squeeze = T.alloc_buffer((seq_len * 8,), "int32")
    # Stage 1: A -> T_expand_dims[0, :], one element per (blockIdx.x, threadIdx.x).
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_expand_dims"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(A[v0])
                T.writes(T_expand_dims[0, v0])
                T_expand_dims[0, v0] = A[v0]
    # Stage 2: scan of T_expand_dims into output_buf, written with explicit thread launches.
    with T.block("exclusive_scan"):
        T.reads(T_expand_dims[0, 0:seq_len * 8])
        T.writes(output_buf[0, 0:seq_len * 8])
        if seq_len * 8 == 0:
            # Empty input: launch a single block that does nothing.
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
            if blockIdx_x < 1:
                T.evaluate(0)
        else:
            # Copy T_expand_dims into output_buf, 256 threads per block.
            with T.launch_thread("threadIdx.x", 256) as threadIdx_x:
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, (seq_len * 8 + 255) // 256))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                if blockIdx_x * 256 + threadIdx_x < seq_len * 8:
                    output_buf[(blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) // (seq_len * 8), (blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) % (seq_len * 8)] = T_expand_dims[(blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) // (seq_len * 8), (blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) % (seq_len * 8)]
            # Up-sweep: segment width is shift_left(2, i) == 2**(i + 1); each thread
            # handles one segment [start, end) and adds out[middle - 1] into out[end - 1].
            for i in range(T.Cast("int32", T.ceil(T.log2(T.Cast("float32", seq_len * 8))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 256)
                # NOTE(review): this launch extent depends on seq_len * 8. After CSE it
                # presumably becomes the cse_var_3 that is left free in the generated
                # device kernel -- TODO confirm.
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (seq_len * 8 + (256 * T.shift_left(2, i) - 1)) // (256 * T.shift_left(2, i)))))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                start = T.allocate([1], "int64", "local")
                middle = T.allocate([1], "int64", "local")
                end = T.allocate([1], "int64", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[0] = T.Cast("int64", T.shift_left(2, i)) * T.Cast("int64", blockIdx_x * 256 + threadIdx_x)
                if start_1[0] < T.Cast("int64", seq_len * 8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[0] = start_1[0] + T.Cast("int64", T.shift_left(2, i) // 2)
                    end_1 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_1[0] = T.min(start_1[0] + T.Cast("int64", T.shift_left(2, i)), T.Cast("int64", seq_len * 8))
                    if middle_1[0] < T.Cast("int64", seq_len * 8):
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] + output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
            # Set the last element to 0 before the down-sweep.
            with T.launch_thread("blockIdx.x", 1) as blockIdx_x:
                if blockIdx_x < 1:
                    output_buf[((blockIdx_x + 1) * (seq_len * 8) - 1) // (seq_len * 8), ((blockIdx_x + 1) * (seq_len * 8) - 1) % (seq_len * 8)] = 0
            # Down-sweep: segment widths shrink as j grows. Each thread saves
            # out[middle - 1], overwrites it with out[end - 1], then adds the saved
            # value into out[end - 1].
            for j in range(T.Cast("int32", T.ceil(T.log2(T.Cast("float32", seq_len * 8))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 256)
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (T.Cast("int64", seq_len * 8) + (T.int64(256) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) - T.int64(1))) // (T.int64(256) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1))))))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                start = T.allocate([1], "int64", "local")
                middle = T.allocate([1], "int64", "local")
                end = T.allocate([1], "int64", "local")
                # int32 scratch slot that holds the old out[middle - 1] during the swap.
                end_1 = T.allocate([1], "int32", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[0] = T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) * T.Cast("int64", blockIdx_x * 256 + threadIdx_x)
                if start_1[0] < T.Cast("int64", seq_len * 8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[0] = start_1[0] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) // T.int64(2)
                    end_2 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_2[0] = T.min(start_1[0] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)), T.Cast("int64", seq_len * 8))
                    if middle_1[0] < T.Cast("int64", seq_len * 8):
                        end_3 = T.Buffer((1,), "int32", data=end_1, scope="local")
                        end_3[0] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] + end_3[0]
    # Stage 3: output_buf[0, :] -> T_squeeze (back to 1-D).
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_squeeze"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(output_buf[0, v0])
                T.writes(T_squeeze[v0])
                T_squeeze[v0] = output_buf[0, v0]
    # Stage 4: element-wise T_add = A + T_squeeze.
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_add"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(A[v0], T_squeeze[v0])
                T.writes(T_add[v0])
                T_add[v0] = A[v0] + T_squeeze[v0]
                
# Lower and build the reproducer for Metal. This pipeline runs SplitHostDevice;
# per the report, the resulting device kernel references the undefined cse_var_3.
tvm.build(cumsum, target="metal")

cc: @Lunderberg

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it); type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions