Skip to content

Commit cc79e8f

Browse files
authored
[TIR] Add a new intrinsic count leading zeros for LLVM and SPIR-V (#7825)
1 parent aa9cb63 commit cc79e8f

File tree

7 files changed

+93
-3
lines changed

7 files changed

+93
-3
lines changed

include/tvm/tir/op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,7 @@ TVM_DECLARE_INTRIN_UNARY(atan);
864864
TVM_DECLARE_INTRIN_UNARY(acosh);
865865
TVM_DECLARE_INTRIN_UNARY(asinh);
866866
TVM_DECLARE_INTRIN_UNARY(atanh);
867+
TVM_DECLARE_INTRIN_UNARY(clz);
867868

868869
#define TVM_DECLARE_INTRIN_BINARY(OpName) \
869870
inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \

python/tvm/tir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
from .op import call_packed, call_intrin, call_pure_extern, call_extern
3939
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
40-
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
40+
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
4141
from .op import sin, sinh, asin, asinh
4242
from .op import cos, cosh, acos, acosh
4343
from .op import tan, tanh, atan, atan2, atanh

python/tvm/tir/op.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,22 @@ def rsqrt(x):
752752
return call_intrin(x.dtype, "tir.rsqrt", x)
753753

754754

755+
def clz(x):
    """Count the leading zero bits of an integer expression.

    Parameters
    ----------
    x : PrimExpr
        Input argument. The result is undefined if the input is 0.

    Returns
    -------
    y : PrimExpr
        The result, built as an ``int32`` call to the ``tir.clz`` intrinsic.
    """
    # The call is always typed int32, independent of the dtype of ``x``.
    result_dtype = "int32"
    return call_intrin(result_dtype, "tir.clz", x)
769+
770+
755771
def floor(x, span=None):
756772
"""Take floor of float input x.
757773

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
160160
*rv = ret;
161161
});
162162

163+
// Lowering rule for tir.clz on the LLVM target: rewrites the call into a
// call_llvm_intrin of llvm.ctlz and casts the result to the call's dtype.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.clz").set_body([](const TVMArgs& targs, TVMRetValue* rv) {
  PrimExpr expr = targs[0];
  const tir::CallNode* op = expr.as<tir::CallNode>();
  ICHECK(op != nullptr);
  ICHECK_EQ(op->args.size(), 1);
  PrimExpr operand = op->args[0];
  // Arguments, in order: the LLVM intrinsic id, the constant 2 (presumably the
  // number of leading arguments used to resolve the intrinsic signature --
  // confirm against call_llvm_intrin), the operand, and is_zero_undef = true.
  Array<PrimExpr> intrin_args{
      IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz),
      IntImm(DataType::UInt(32), 2),
      operand,
      IntImm(DataType::Int(1), 1),
  };
  // LLVM requires that the return type must match the first argument type,
  // so the intrinsic is typed like the operand and cast afterwards.
  PrimExpr leading_zeros =
      tir::Call(operand->dtype, tir::builtin::call_llvm_intrin(), intrin_args);
  *rv = cast(op->dtype, leading_zeros);
});
177+
163178
} // namespace llvm
164179
} // namespace codegen
165180
} // namespace tvm

src/target/spirv/intrin_rule_spirv.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <tvm/runtime/registry.h>
2525
#include <tvm/tir/builtin.h>
2626
#include <tvm/tir/expr.h>
27+
#include <tvm/tir/op.h>
2728

2829
namespace tvm {
2930
namespace codegen {
@@ -32,8 +33,9 @@ namespace spirv {
3233
using namespace runtime;
3334

3435
// num_signature means number of arguments used to query signature
36+
3537
template <unsigned id>
36-
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
38+
PrimExpr CallGLSLIntrin(const TVMArgs& targs, TVMRetValue* rv) {
3739
PrimExpr e = targs[0];
3840
const tir::CallNode* call = e.as<tir::CallNode>();
3941
ICHECK(call != nullptr);
@@ -44,7 +46,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
4446
for (PrimExpr arg : call->args) {
4547
cargs.push_back(arg);
4648
}
47-
*rv = tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
49+
return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
50+
}
51+
52+
template <unsigned id>
53+
inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
54+
*rv = CallGLSLIntrin<id>(targs, rv);
4855
}
4956

5057
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
@@ -76,6 +83,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntri
7683

7784
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
7885

86+
// Lowering rule for tir.clz on Vulkan: computes
// (bit width of the operand - 1) - FindUMsb(operand).
// NOTE(review): GLSL.std.450 FindUMsb is specified for 32-bit components;
// confirm the behaviour of this rule for wider operands such as int64.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.clz")
    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
      PrimExpr expr = targs[0];
      const tir::CallNode* op = expr.as<tir::CallNode>();
      ICHECK(op != nullptr);
      ICHECK_EQ(op->args.size(), 1);
      // Index of the most significant bit position for the operand's dtype.
      const int top_bit_index = op->args[0].dtype().bits() - 1;
      PrimExpr msb_index = CallGLSLIntrin<GLSLstd450FindUMsb>(targs, rv);
      *rv = PrimExpr(top_bit_index) - msb_index;
    });
96+
7997
// WebGPU rules.
8098
TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor")
8199
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);

src/tir/op/op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,8 @@ TIR_REGISTER_PURE_UNARY_OP("tir.asinh");
858858

859859
TIR_REGISTER_PURE_UNARY_OP("tir.atanh");
860860

861+
TIR_REGISTER_PURE_UNARY_OP("tir.clz");
862+
861863
// binary intrinsics
862864
TIR_REGISTER_PURE_BINARY_OP("tir.atan2");
863865

tests/python/unittest/test_tir_intrin.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,47 @@ def test_ldexp():
142142
)
143143

144144

145+
def test_clz():
    """Compare tir.clz with a NumPy reference on each enabled target."""

    def clz_np(x, dtype):
        """NumPy reference for count-leading-zeros of positive integers.

        ``dtype`` must be a string ending in the two-digit bit width
        (e.g. "int32"), since the width is parsed from its last two characters.
        """
        ceil_log2 = np.ceil(np.log2(x)).astype(dtype)
        bits = int(dtype[-2:])
        clz = bits - ceil_log2
        # For exact powers of two ceil(log2(x)) equals floor(log2(x)), so
        # bits - ceil_log2 overcounts by one and needs the correction below.
        clz[np.bitwise_and(x, x - 1) == 0] -= 1
        return clz

    for target in ["llvm", "vulkan"]:
        # Gate on the target under test. Checking "vulkan" unconditionally
        # would skip the llvm case whenever Vulkan is unavailable.
        if not tvm.testing.device_enabled(target):
            continue

        for dtype in ["int32", "int64"]:
            m = te.var("m")
            A = te.placeholder((m,), name="A", dtype=dtype)
            B = te.compute((m,), lambda *i: tvm.tir.clz(A(*i)), name="B")
            s = te.create_schedule(B.op)

            if target == "vulkan":
                # Bind the split axes to block and thread indices for Vulkan.
                bx, tx = s[B].split(B.op.axis[0], factor=64)

                s[B].bind(bx, te.thread_axis("blockIdx.x"))
                s[B].bind(tx, te.thread_axis("threadIdx.x"))

            f = tvm.build(s, [A, B], target)
            dev = tvm.device(target, 0)
            n = 10

            for high in [10, 100, 1000, 10000, 100000, 1000000]:
                # Inputs start at 1 because clz is undefined for 0.
                a_np = np.random.randint(1, high=high, size=(n,)).astype(dtype)
                a = tvm.nd.array(a_np, dev)
                # tir.clz always yields int32, regardless of the input dtype.
                b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev)
                f(a, b)
                ref = clz_np(a_np, dtype)
                np.testing.assert_equal(b.asnumpy(), ref)
180+
181+
145182
if __name__ == "__main__":
    # Run each intrinsic test in order when this file is executed as a script.
    for run_test in (
        test_nearbyint,
        test_unary_intrin,
        test_round_intrinsics_on_int,
        test_binary_intrin,
        test_ldexp,
        test_clz,
    ):
        run_test()

0 commit comments

Comments
 (0)