Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst));
Integer(ScriptDtypePrintLocation::kFirst))
.set_attr<TVectorizable>("TVectorizable", true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a small test case for the IR printer?
tests/python/unittest/test_tvmscript_printer_tir.py


// Defines the call_spirv_pure_glsl450 builtin via TIR_DEFINE_BUILTIN_FUNC and
// sets its "TCallEffectKind" attribute to CallEffectKind::kPure.
TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
Expand Down
23 changes: 22 additions & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
} else {
int lane = 0;
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
Array<PrimExpr> new_args;
if (op->op.same_as(builtin::call_llvm_pure_intrin())) {
// op->args[1], will give us total number of arguments to intrinsic
int num_signature = Downcast<IntImm>(op->args[1])->value;
Array<PrimExpr> op_expr_args;
for (int i = 0; i < num_signature; i++) {
// Collect all intrinsic arguments
op_expr_args.push_back(op->args[i + 2]);
}
// Generate RAMP nodes for intrinsic arguments
Array<PrimExpr> updated_args = MutateArray(op_expr_args, &lane);
// Collect Intrinsic ID and no. of argument
for (int i = 0; i < 2; i++) {
new_args.push_back(op->args[i]);
}
// Collect updated intrinsic arguments
for (int i = 0; i < num_signature; i++) {
new_args.push_back(updated_args[i]);
}
} else {
new_args = MutateArray(op->args, &lane);
}
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
Expand Down
58 changes: 58 additions & 0 deletions tests/python/tir-transform/test_tir_transform_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,5 +790,63 @@ def expected(a: T.handle, b: T.handle):
tvm.ir.assert_structural_equal(after, expected)


@pytest.mark.parametrize(
    "extent, vec_str, target",
    [(4, "float32x4", simple_target)],
)
def test_vectorize_llvm_pure_intrin(extent, vec_str, target):
    """Check that a vectorized loop around ``T.call_llvm_pure_intrin`` is rewritten and builds.

    ``Before`` holds a ``T.vectorized`` loop of ``extent`` iterations calling
    ``llvm.sqrt`` on a scalar load. ``After`` is the module the pass output is
    compared against: the loop is gone, the buffers are indexed with
    ``T.Ramp(0, 1, extent)`` and the call dtype is ``vec_str``.
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
            for j in T.vectorized(extent):
                A[j] = T.call_llvm_pure_intrin(
                    "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j]
                )

    @I.ir_module
    class After:
        @T.prim_func
        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
            A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
                vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
            )

    with tvm.target.Target(target):
        vectorize_pass = tvm.tir.transform.VectorizeLoop()
        vectorized = vectorize_pass(Before)
        # First compare the IR produced by the pass, then make sure it still builds.
        tvm.ir.assert_structural_equal(vectorized, After)
        tvm.build(vectorized, target)


@pytest.mark.parametrize(
    "extent, vec_str, target",
    [(4, "int32x4", simple_target)],
)
def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target):
    """Check that building a vectorized ``llvm.lround`` call is rejected.

    ``Before`` holds a ``T.vectorized`` loop of ``extent`` iterations calling
    ``llvm.lround`` through ``T.call_llvm_pure_intrin``. ``After`` is the module
    the ``VectorizeLoop`` output is compared against (``T.Ramp`` indices, call
    dtype ``vec_str``). ``tvm.build`` on that module is expected to raise an
    exception whose first argument contains
    "Intrinsic does not support vectors".
    """

    @I.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
            for j in T.vectorized(extent):
                A[j] = T.call_llvm_pure_intrin(
                    "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j]
                )

    @I.ir_module
    class After:
        @T.prim_func
        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
            A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
                vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
            )

    with tvm.target.Target(target):
        mod = tvm.tir.transform.VectorizeLoop()(Before)
        # Compare the IR before building: placed after tvm.build inside
        # pytest.raises, this check would be skipped whenever the build raises,
        # which is exactly the case this test expects.
        tvm.ir.assert_structural_equal(mod, After)
        # Only the build is expected to fail, so only the build is guarded.
        with pytest.raises(Exception) as e_info:
            tvm.build(mod, target)
    assert "Intrinsic does not support vectors" in e_info.value.args[0]


if __name__ == "__main__":
tvm.testing.main()
21 changes: 21 additions & 0 deletions tests/python/tvmscript/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
_assert_print(main, expected_output)


def test_vectorize_llvm_pure_intrin():
    # Printer test: a prim_func that stores the result of a "float32x4"
    # T.call_llvm_pure_intrin("llvm.sqrt") into A over T.Ramp(0, 1, 4) is passed
    # to _assert_print together with the expected script text.
    from tvm.script import tir as T

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (4,), "float32")
        B = T.match_buffer(b, (4,), "float32")
        A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin(
            "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)]
        )

    # The expected text shows the Ramp indices as the slices A[0:4] / B[0:4].
    # NOTE(review): whitespace inside this literal is part of the expected text;
    # presumably the function body line is indented in the checked-in file —
    # confirm before applying.
    expected_output = """
# from tvm.script import tir as T

@T.prim_func
def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4])
"""
    _assert_print(main, expected_output)


if __name__ == "__main__":
tvm.testing.main()