Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions python/tvm/topi/cuda/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)

lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int32"
)

with ib.if_scope(scan_axis_size == 0):
with ib.new_scope():
bx = te.thread_axis("blockIdx.x")
Expand Down Expand Up @@ -95,13 +99,11 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):

# The following algorithm performs parallel exclusive scan
# Up Sweep of exclusive scan
lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
with ib.for_range(0, lim, dtype="int32") as l2_width:
width = 2 << l2_width

with ib.new_scope():
width = tvm.tir.generic.cast(width, "int64")
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
Expand Down Expand Up @@ -136,10 +138,11 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add):
reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
output[(bx + 1) * scan_axis_size - 1] = cast(0, out_dtype)

with ib.for_range(0, lim, dtype="int64") as l2_width:
with ib.for_range(0, lim, dtype="int32") as l2_width:
width = 2 << (lim - l2_width - 1)

with ib.new_scope():
width = tvm.tir.generic.cast(width, "int64")
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,13 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even):
bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even)

lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int32"
)
with ib.for_range(0, lim, dtype="int64") as l2_width:
with ib.for_range(0, lim, dtype="int32") as l2_width:
width = 2 << l2_width
# Define and launch the cuda kernel
with ib.new_scope():
width = tvm.tir.generic.cast(width, "int64")
i = ib.allocate("int64", (1,), name="i", scope="local")
j = ib.allocate("int64", (1,), name="j", scope="local")
start = ib.allocate("int64", (1,), name="start", scope="local")
Expand Down
2 changes: 2 additions & 0 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntri

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin<GLSLstd450Log2>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)

for device in ["llvm", "cuda", "opencl"]:
for device in ["llvm", "cuda", "opencl", "vulkan"]:
check_device(device)


Expand Down