2424#include < tvm/runtime/registry.h>
2525#include < tvm/tir/builtin.h>
2626#include < tvm/tir/expr.h>
27+ #include < tvm/tir/op.h>
2728
2829namespace tvm {
2930namespace codegen {
@@ -32,8 +33,9 @@ namespace spirv {
3233using namespace runtime ;
3334
3435// num_signature means number of arguments used to query signature
36+
3537template <unsigned id>
36- inline void DispatchGLSLPureIntrin (const TVMArgs& targs, TVMRetValue* rv) {
38+ PrimExpr CallGLSLIntrin (const TVMArgs& targs, TVMRetValue* rv) {
3739 PrimExpr e = targs[0 ];
3840 const tir::CallNode* call = e.as <tir::CallNode>();
3941 ICHECK (call != nullptr );
@@ -44,7 +46,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
4446 for (PrimExpr arg : call->args ) {
4547 cargs.push_back (arg);
4648 }
47- *rv = tir::Call (call->dtype , tir::builtin::call_spirv_pure_glsl450 (), cargs);
49+ return tir::Call (call->dtype , tir::builtin::call_spirv_pure_glsl450 (), cargs);
50+ }
51+
52+ template <unsigned id>
53+ inline void DispatchGLSLPureIntrin (const TVMArgs& targs, TVMRetValue* rv) {
54+ *rv = CallGLSLIntrin<id>(targs, rv);
4855}
4956
5057TVM_REGISTER_GLOBAL (" tvm.intrin.rule.vulkan.floor" )
@@ -76,6 +83,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntri
7683
7784TVM_REGISTER_GLOBAL (" tvm.intrin.rule.vulkan.tanh" ).set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
7885
86+ TVM_REGISTER_GLOBAL (" tvm.intrin.rule.vulkan.clz" )
87+ .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
88+ PrimExpr e = targs[0 ];
89+ const tir::CallNode* call = e.as <tir::CallNode>();
90+ ICHECK (call != nullptr );
91+ ICHECK_EQ (call->args .size (), 1 );
92+ PrimExpr arg = call->args [0 ];
93+ PrimExpr msb = CallGLSLIntrin<GLSLstd450FindUMsb>(targs, rv);
94+ *rv = PrimExpr (arg.dtype ().bits () - 1 ) - msb;
95+ });
96+
7997// WebGPU rules.
8098TVM_REGISTER_GLOBAL (" tvm.intrin.rule.webgpu.floor" )
8199 .set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
0 commit comments