From 104cc54c0c220e93640cbad8501ceb16ee158acc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 04:12:45 +0900 Subject: [PATCH 1/2] Avoid passing int64 scalar to vulkan runtime --- python/tvm/topi/cuda/sort.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index c0f076fb6065..41f1eda9acbd 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -198,12 +198,13 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even): bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even) lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int32" ) - with ib.for_range(0, lim, dtype="int64") as l2_width: + with ib.for_range(0, lim, dtype="int32") as l2_width: width = 2 << l2_width # Define and launch the cuda kernel with ib.new_scope(): + width = tvm.tir.generic.cast(width, "int64") i = ib.allocate("int64", (1,), name="i", scope="local") j = ib.allocate("int64", (1,), name="j", scope="local") start = ib.allocate("int64", (1,), name="start", scope="local") From 00e503ad3386100b1b8f20a7403660a6392b2fe8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Feb 2021 04:38:29 +0900 Subject: [PATCH 2/2] get_valid_count works on vulkan but log2 is done by spirv not host --- python/tvm/topi/cuda/scan.py | 13 ++++++++----- src/target/spirv/intrin_rule_spirv.cc | 2 ++ tests/python/topi/python/test_topi_vision.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 0bdab100b429..5d3ce9b3c5e7 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -67,6 +67,10 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + lim = tvm.tir.generic.cast( + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int32" + ) + with ib.if_scope(scan_axis_size == 0): with ib.new_scope(): bx = te.thread_axis("blockIdx.x") @@ -95,13 +99,11 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): # The following algorithm performs parallel exclusive scan # Up Sweep of exclusive scan - lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" - ) - with ib.for_range(0, lim, dtype="int64") as l2_width: + with ib.for_range(0, lim, dtype="int32") as l2_width: width = 2 << l2_width with ib.new_scope(): + width = tvm.tir.generic.cast(width, "int64") tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) @@ -136,10 +138,11 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add): reduction[bx] = output[(bx + 1) * scan_axis_size - 1] output[(bx + 1) * scan_axis_size - 1] = cast(0, out_dtype) - with ib.for_range(0, lim, dtype="int64") as l2_width: + with ib.for_range(0, lim, dtype="int32") as l2_width: width = 2 << (lim - l2_width - 1) with ib.new_scope(): + width = tvm.tir.generic.cast(width, "int64") tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 90b2eb2a671f..916662abf37d 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -64,6 +64,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntri TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin); diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py index 697ef8a24f67..a2eb0f65b5d0 100644 --- a/tests/python/topi/python/test_topi_vision.py +++ b/tests/python/topi/python/test_topi_vision.py @@ -112,7 +112,7 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3) - for device in ["llvm", "cuda", "opencl"]: + for device in ["llvm", "cuda", "opencl", "vulkan"]: check_device(device)