diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index cf82eb07edf2..67d01aa92389 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin) TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TScriptDtypePrintLocation", - Integer(ScriptDtypePrintLocation::kFirst)); + Integer(ScriptDtypePrintLocation::kFirst)) + .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 63569f342aed..b4e3d67e500e 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -550,7 +550,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor new_args = MutateArray(op->args, &lane); + Array new_args; + if (op->op.same_as(builtin::call_llvm_pure_intrin())) { + // op->args[1], will give us total number of arguments to intrinsic + int num_signature = Downcast(op->args[1])->value; + Array op_expr_args; + for (int i = 0; i < num_signature; i++) { + // Collect all intrinsic arguments + op_expr_args.push_back(op->args[i + 2]); + } + // Generate RAMP nodes for intrinsic arguments + Array updated_args = MutateArray(op_expr_args, &lane); + // Collect Intrinsic ID and no. of argument + for (int i = 0; i < 2; i++) { + new_args.push_back(op->args[i]); + } + // Collect updated intrinsic arguments + for (int i = 0; i < num_signature; i++) { + new_args.push_back(updated_args[i]); + } + } else { + new_args = MutateArray(op->args, &lane); + } // normal code path. if (op->args.same_as(new_args)) { return GetRef(op); diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 7523cab54941..9659d896aed8 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -790,5 +790,63 @@ def expected(a: T.handle, b: T.handle): tvm.ir.assert_structural_equal(after, expected) +@pytest.mark.parametrize( + "extent, vec_str, target", + [(4, "float32x4", simple_target)], +) +def test_vectorize_llvm_pure_intrin(extent, vec_str, target): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.call_llvm_pure_intrin( + "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j] + ) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin( + vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)] + ) + + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + mod = tvm.build(mod, target) + + +@pytest.mark.parametrize( + "extent, vec_str, target", + [(4, "int32x4", simple_target)], +) +def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.call_llvm_pure_intrin( + "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j] + ) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin( + vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)] + ) + + with pytest.raises(Exception) as e_info: + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + ex = tvm.build(mod, target) + tvm.ir.assert_structural_equal(mod, After) + assert "Intrinsic does not support vectors" in e_info.value.args[0] + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 9e77fa090021..8364e65a4178 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): _assert_print(main, expected_output) +def test_vectorize_llvm_pure_intrin(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (4,), "float32") + B = T.match_buffer(b, (4,), "float32") + A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin( + "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)] + ) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4]) + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main()