The pass lower_warp_memory cannot handle more than one warp buffer. Buffers other than the first one are not correctly transformed to warp shuffles.
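For context, the "warp" memory scope keeps each buffer element in a per-lane register, and cross-lane reads are supposed to be lowered to warp shuffle instructions. As an illustration of the intended semantics for the compute in the reproduction script below (hand-written CUDA, not TVM's actual generated code; the kernel name and host harness are mine), a correct lowering would exchange both the A and B values across lanes:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Illustration only: C[i] = A[(i+1) % 32] + B[(i+1) % 32] computed with
// warp shuffles instead of shared memory, for a single 32-thread warp.
__global__ void rotate_add(const float* A, const float* B, float* C) {
  int i = threadIdx.x;              // single warp: blockDim.x == 32
  unsigned mask = 0xffffffffu;      // all 32 lanes participate
  float a = A[i];
  float b = B[i];
  // Read lane (i+1) % 32's register copies of both A and B.
  float a_rot = __shfl_sync(mask, a, (i + 1) & 31);
  float b_rot = __shfl_sync(mask, b, (i + 1) & 31);
  C[i] = a_rot + b_rot;
}

int main() {
  const int m = 32;
  float hA[m], hB[m], hC[m];
  for (int i = 0; i < m; ++i) { hA[i] = i; hB[i] = 10.0f * i; }
  float *dA, *dB, *dC;
  cudaMalloc(&dA, m * sizeof(float));
  cudaMalloc(&dB, m * sizeof(float));
  cudaMalloc(&dC, m * sizeof(float));
  cudaMemcpy(dA, hA, m * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB, m * sizeof(float), cudaMemcpyHostToDevice);
  rotate_add<<<1, m>>>(dA, dB, dC);
  cudaMemcpy(hC, dC, m * sizeof(float), cudaMemcpyDeviceToHost);
  printf("C[0] = %f (expect %f)\n", hC[0], hA[1] + hB[1]);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}
```

The bug is that only the shuffles for the first warp buffer (A) are generated; the second buffer (B) is left untransformed.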
To reproduce:
import tvm
import topi
import numpy as np
from tvm import te

dtype = "float32"
target = "cuda"
m = 32
A = te.placeholder((m,), name='A', dtype=dtype)
B = te.placeholder((m,), name='B', dtype=dtype)
C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name='C')

cuda_target = tvm.target.create("cuda")
assert m <= cuda_target.thread_warp_size
with cuda_target:
    s = te.create_schedule(C.op)
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")
    AA = s.cache_read(A, "warp", [C])
    BB = s.cache_read(B, "warp", [C])
    xo, xi = s[C].split(C.op.axis[0], nparts=1)
    s[C].bind(xi, tx)
    s[C].bind(xo, bx)
    s[AA].compute_at(s[C], xo)
    s[BB].compute_at(s[C], xo)
    xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
    s[AA].bind(xo, bx)
    s[AA].bind(xi, tx)
    xo, xi = s[BB].split(s[BB].op.axis[0], nparts=1)
    s[BB].bind(xo, bx)
    s[BB].bind(xi, tx)

    print(tvm.lower(s, [A, B, C], target, simple_mode=True))
    compute = tvm.build(s, [A, B, C], target, name="run")
    print(compute.imported_modules[0].get_source())

I think the problem is that WarpMemoryRewriter::VisitStmt_(const AllocateNode*) in lower_warp_memory.cc does not continue the recursion after rewriting the first buffer.
I will fix it.
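For reference, this is the shape of the fix I have in mind: mutate the children first, then rewrite the current allocation if it is a warp buffer, so that Allocate nodes of later warp buffers nested inside the first one are still visited. This is a sketch only; the surrounding member names (warp_buffer_, warp_size_, analyzer_, WarpAccessRewriter) are written from my reading of lower_warp_memory.cc and may not match the file exactly.

```cpp
// Sketch only: member and helper names are assumptions based on
// lower_warp_memory.cc, not a verbatim patch.
Stmt WarpMemoryRewriter::VisitStmt_(const AllocateNode* op) {
  // Recurse into the children first, so that Allocate nodes of *other*
  // warp buffers nested inside this one are also visited and rewritten.
  Stmt ret = StmtMutator::VisitStmt_(op);
  op = ret.as<AllocateNode>();
  if (warp_buffer_.count(op->buffer_var.get())) {
    // Then rewrite this buffer's accesses into warp shuffles, as before.
    WarpAccessRewriter rewriter(warp_size_, &analyzer_);
    ret = rewriter.Rewrite(op);
  }
  return ret;
}
```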