From e5d196a0a8acd5ef3a288f7b281f4090fb40f952 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 16 Mar 2021 15:29:35 +0900 Subject: [PATCH 1/8] [TOPI] Cast to float32 before log2 in sort/scan --- python/tvm/topi/cuda/scan.py | 2 +- python/tvm/topi/cuda/sort.py | 4 ++-- tests/python/unittest/test_target_codegen_spirv.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 5d3798e3d27b..2d9533b61ddd 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -104,7 +104,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i # The following algorithm performs parallel exclusive scan # Up Sweep of exclusive scan lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float32"))), "int64" ) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 93e4d3feccc7..257691c0466f 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -239,7 +239,7 @@ def compare(a, b): # Sort the lower levels of the merge using odd-even sort, it's fast for small inputs lower_lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float32"))), "int64" ) _odd_even_sort( @@ -255,7 +255,7 @@ def compare(a, b): ) upper_lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float32"))), "int64" ) def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count): diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py index df42eeb721ab..b9f07cf426fe 100644 --- a/tests/python/unittest/test_target_codegen_spirv.py +++ b/tests/python/unittest/test_target_codegen_spirv.py @@ -104,6 +104,16 @@ def test_pushconstants(): check_mod(mod, x_np, res_np) + # One 64 bit and one 32 bit constants + dtype = "int32" + x = relay.var("x", shape=(relay.Any(),), dtype=dtype) + mod = tvm.IRModule() + mod["main"] = relay.Function([x], relay.cumsum(x)) + x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype) + res_np = np.cumsum(x_np) + + check_mod(mod, x_np, res_np) + def test_unique(): if not tvm.testing.device_enabled("vulkan"): From c7015956e7b3ed4f705f2152df105aa79226b190 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 17 Mar 2021 04:01:48 +0900 Subject: [PATCH 2/8] revert sort change since this seems unnecessary --- python/tvm/topi/cuda/sort.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 257691c0466f..93e4d3feccc7 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -239,7 +239,7 @@ def compare(a, b): # Sort the lower levels of the merge using odd-even sort, it's fast for small inputs lower_lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float32"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64" ) _odd_even_sort( @@ -255,7 +255,7 @@ def compare(a, b): ) upper_lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float32"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" ) def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count): From 1bdd8926ea926f54e4db4b9f9dded24689c9210b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 17 Mar 2021 04:19:01 +0900 Subject: [PATCH 3/8] only does cast to float32 on vk + dynamic input case --- python/tvm/topi/cuda/scan.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 2d9533b61ddd..f98f42e3264a 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, too-many-locals, too-many-statements "Scan related operators" from typing import Callable, Optional, Union +import logging import tvm from tvm import te @@ -101,10 +102,24 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i nthread_bx = ceil_div(scan_axis_size, max_threads) nthread_by = batch_size + size_cast_dtype = "float64" + target = tvm.target.Target.current() + + if "vulkan" in str(target) and isinstance(scan_axis_size, tvm.tir.expr.Var): + # SPIRV seems to have an issue with float64 intrinsic + # TODO(masahi): Eliminate this concern by adding TIR level CSE + msg = """ + Casting the dynamic input size to float32 before computing log2 in exclusive scan. + This could result in a wrong output if the runtime value of the input size is very large. + """ + logging.warning(msg) + size_cast_dtype = "float32" + # The following algorithm performs parallel exclusive scan # Up Sweep of exclusive scan lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float32"))), "int64" + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, size_cast_dtype))), + "int64", ) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width From 10a4078a63a6d861c1ade5f1aa59e58aacbfb22e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 17 Mar 2021 04:20:36 +0900 Subject: [PATCH 4/8] check against IntImm instead of Var --- python/tvm/topi/cuda/scan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index f98f42e3264a..abe929672df4 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i size_cast_dtype = "float64" target = tvm.target.Target.current() - if "vulkan" in str(target) and isinstance(scan_axis_size, tvm.tir.expr.Var): + if "vulkan" in str(target) and not isinstance(scan_axis_size, tvm.tir.expr.IntImm): # SPIRV seems to have an issue with float64 intrinsic # TODO(masahi): Eliminate this concern by adding TIR level CSE msg = """ From aabc763e8992e7abacfb25a8bff47e8e2596b435 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 10 Apr 2021 18:53:55 +0900 Subject: [PATCH 5/8] revert change --- python/tvm/topi/cuda/scan.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index abe929672df4..5d3798e3d27b 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -17,7 +17,6 @@ # pylint: disable=invalid-name, too-many-locals, too-many-statements "Scan related operators" from typing import Callable, Optional, Union -import logging import tvm from tvm import te @@ -102,24 +101,10 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i nthread_bx = ceil_div(scan_axis_size, max_threads) nthread_by = batch_size - size_cast_dtype = "float64" - target = tvm.target.Target.current() - - if "vulkan" in str(target) and not isinstance(scan_axis_size, tvm.tir.expr.IntImm): - # SPIRV seems to have an issue with float64 intrinsic - # TODO(masahi): Eliminate this concern by adding TIR level CSE - msg = """ - Casting the dynamic input size to float32 before computing log2 in exclusive scan. - This could result in a wrong output if the runtime value of the input size is very large. - """ - logging.warning(msg) - size_cast_dtype = "float32" - # The following algorithm performs parallel exclusive scan # Up Sweep of exclusive scan lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, size_cast_dtype))), - "int64", + tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" ) with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width From a70dd1da45a64f48c012ac3fd1519aded97575f1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 10 Apr 2021 18:57:28 +0900 Subject: [PATCH 6/8] use clz for ceil_log2 when compiling for vk --- python/tvm/topi/cuda/scan.py | 7 +++---- python/tvm/topi/cuda/sort.py | 10 +++------- python/tvm/topi/math.py | 22 ++++++++++++++++++++++ 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 5d3798e3d27b..6dbaf02191c8 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -23,7 +23,7 @@ from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust from .. import tag -from ..math import cast +from ..math import cast, ceil_log2 from ..transform import expand_dims, reshape, squeeze, transpose from ..utils import ceil_div, get_const_int, prod, swap from .injective import schedule_injective_from_existing @@ -103,9 +103,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i # The following algorithm performs parallel exclusive scan # Up Sweep of exclusive scan - lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(scan_axis_size, "float64"))), "int64" - ) + lim = ceil_log2(scan_axis_size) + with ib.for_range(0, lim, dtype="int64") as l2_width: width = 2 << l2_width diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 93e4d3feccc7..25cc7a4e2cfb 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -23,7 +23,7 @@ from ..transform import strided_slice, transpose from .. import tag from ..utils import ceil_div, swap -from ..math import cast +from ..math import cast, ceil_log2 def _schedule_sort(outs): @@ -238,9 +238,7 @@ def compare(a, b): return out # Sort the lower levels of the merge using odd-even sort, it's fast for small inputs - lower_lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(block_size, "float64"))), "int64" - ) + lower_lim = ceil_log2(block_size) _odd_even_sort( ib, @@ -254,9 +252,7 @@ def compare(a, b): values_swap, ) - upper_lim = tvm.tir.generic.cast( - tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64" - ) + upper_lim = ceil_log2(size) def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count): first = ib.allocate("int64", (1,), name="first", scope="local") diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index cf6fcbb88c7e..fb0ee3b04cd7 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -742,3 +742,25 @@ def fast_erf(x): The result. """ return cpp.fast_erf(x, x.dtype, tag.ELEMWISE) + + +def ceil_log2(x): + """TODO""" + if not isinstance(x, tvm.tir.PrimExpr): + x = tvm.tir.const(x) + + if "float" in x.dtype: + return tvm.tir.ceil(tvm.tir.log2(x)) + + if "vulkan" in tvm.target.Target.current().kind.name: + # SPIR-V does not support log2 on fp64. Instead, we compute ceil_log2 via clz + clz = tvm.tir.clz(x) + bits = int(x.dtype[-2:]) + ceil_log2 = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) + + if ceil_log2.dtype != x.dtype: + return cast(ceil_log2, x.dtype) + + return ceil_log2 + + return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype) From cc2c3f970da05d5d1d41d89bcbe7ca6f5da3d51f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 16 Apr 2021 22:29:48 +0900 Subject: [PATCH 7/8] add doc on ceil_log2 --- python/tvm/topi/math.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index fb0ee3b04cd7..2785a7f39505 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -745,7 +745,20 @@ def fast_erf(x): def ceil_log2(x): - """TODO""" + """Compute integer ceil log2 with a special code path for vulkan + SPIR-V does not support log2 on fp64. Instead, we compute integer ceil_log2 via clz + intrinsic when the target is vulkan. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ if not isinstance(x, tvm.tir.PrimExpr): x = tvm.tir.const(x) @@ -753,7 +766,6 @@ def ceil_log2(x): return tvm.tir.ceil(tvm.tir.log2(x)) if "vulkan" in tvm.target.Target.current().kind.name: - # SPIR-V does not support log2 on fp64. Instead, we compute ceil_log2 via clz clz = tvm.tir.clz(x) bits = int(x.dtype[-2:]) ceil_log2 = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) From 53b78dfc731f22552db533a28f4f26495f105468 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 17 Apr 2021 04:18:13 +0900 Subject: [PATCH 8/8] fix pylint --- python/tvm/topi/math.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index 2785a7f39505..d3ef3daf10dc 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -768,11 +768,11 @@ def ceil_log2(x): if "vulkan" in tvm.target.Target.current().kind.name: clz = tvm.tir.clz(x) bits = int(x.dtype[-2:]) - ceil_log2 = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) + res = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) - if ceil_log2.dtype != x.dtype: - return cast(ceil_log2, x.dtype) + if res.dtype != x.dtype: + return cast(res, x.dtype) - return ceil_log2 + return res return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)