diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 5438aecd753e..ee5a6796c994 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1387,6 +1387,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { op->op.same_as(builtin::end_profile_intrinsic())) { LOG(INFO) << "Ignoring profile_intrinsic ... " << op->op; return nullptr; + } else if (op->op.same_as(builtin::assume())) { + llvm::Value* cond = MakeValue(op->args[0]); + return builder_->CreateAssumption(cond); } else { LOG(FATAL) << "unknown intrinsic " << op->op; } diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index bef682435ebf..44f950c82ad3 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -978,5 +978,29 @@ def test_llvm_target_attributes(): assert n in functions_with_target +@tvm.testing.requires_llvm +def test_llvm_assume(): + """ + Check that LLVM does not error out when generating code with tir.assume. + Verifying for llvm.assume being generated is not easy as the intrinsic and its + related instructions get removed during optimizations + """ + + @T.prim_func + def tir_assume_func(A: T.Buffer((4, 4), "int32"), B: T.Buffer((14,), "int32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_1 = T.Buffer((16,), "int32", data=A.data) + for axis0, axis1 in T.grid(4, 4): + T.assume(axis0 < 3 or axis1 < 2 or A_1[axis0 * 4 + axis1] == 0) + for i in range(14): + B_1 = T.Buffer((14,), "int32", data=B.data) + B_1[i] = A_1[i] * 2 + + mod = tvm.IRModule.from_expr(tir_assume_func) + inp = te.placeholder((4, 4), name="A", dtype="int32") + out = te.placeholder((14,), name="B", dtype="int32") + m = tvm.build(mod, [inp, out], target="llvm") + + if __name__ == "__main__": tvm.testing.main()