From d3be9abe25c6e2eaeeb7c498487e923cfc05de66 Mon Sep 17 00:00:00 2001 From: rutkoor Date: Fri, 10 May 2024 05:14:03 -0500 Subject: [PATCH 1/2] Vector-Codegen support for llvm-pure-intrin --- src/tir/op/builtin.cc | 3 +- src/tir/transforms/vectorize_loop.cc | 23 +++++++- .../test_tir_transform_vectorize.py | 58 +++++++++++++++++++ 3 files changed, 82 insertions(+), 2 deletions(-) diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index cf82eb07edf2..67d01aa92389 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin) TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) .set_attr("TScriptDtypePrintLocation", - Integer(ScriptDtypePrintLocation::kFirst)); + Integer(ScriptDtypePrintLocation::kFirst)) + .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 3f5c07025044..c769e2dab2c4 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -429,7 +429,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor new_args = MutateArray(op->args, &lane); + Array new_args; + if (op->op.same_as(builtin::call_llvm_pure_intrin())) { + // op->args[1], will give us total number of arguments to intrinsic + int num_signature = Downcast(op->args[1])->value; + Array op_expr_args; + for (int i = 0; i < num_signature; i++) { + // Collect all intrinsic arguments + op_expr_args.push_back(op->args[i + 2]); + } + // Generate RAMP nodes for intrinsic arguments + Array updated_args = MutateArray(op_expr_args, &lane); + // Collect Intrinsic ID and no. of argument + for (int i = 0; i < 2; i++) { + new_args.push_back(op->args[i]); + } + // Collect updated intrinsic arguments + for (int i = 0; i < num_signature; i++) { + new_args.push_back(updated_args[i]); + } + } else { + new_args = MutateArray(op->args, &lane); + } // normal code path. if (op->args.same_as(new_args)) { return GetRef(op); diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index de5453eb5c44..821beba3a886 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -488,5 +488,63 @@ def main(A: T.Buffer((16,), "float32")): tvm.tir.transform.VectorizeLoop()(Mod) +@pytest.mark.parametrize( + "extent, vec_str, target", + [(4, "float32x4", simple_target)], +) +def test_vectorize_llvm_pure_intrin(extent, vec_str, target): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.call_llvm_pure_intrin( + "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j] + ) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin( + vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)] + ) + + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + mod = tvm.build(mod, target) + + +@pytest.mark.parametrize( + "extent, vec_str, target", + [(4, "int32x4", simple_target)], +) +def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + for j in T.vectorized(extent): + A[j] = T.call_llvm_pure_intrin( + "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j] + ) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin( + vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)] + ) + + with pytest.raises(Exception) as e_info: + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + ex = tvm.build(mod, target) + tvm.ir.assert_structural_equal(mod, After) + assert "Intrinsic does not support vectors" in e_info.value.args[0] + + if __name__ == "__main__": tvm.testing.main() From 39d28384f3987ddc7db633fba4d3df13a03d1812 Mon Sep 17 00:00:00 2001 From: rutkoor Date: Mon, 3 Jun 2024 03:51:40 -0500 Subject: [PATCH 2/2] Adding test in test_tvmscript_printer_tir.py --- .../tvmscript/test_tvmscript_printer_tir.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 9e77fa090021..8364e65a4178 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): _assert_print(main, expected_output) +def test_vectorize_llvm_pure_intrin(): + from tvm.script import tir as T + + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (4,), "float32") + B = T.match_buffer(b, (4,), "float32") + A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin( + "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)] + ) + + expected_output = """ +# from tvm.script import tir as T + +@T.prim_func +def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4]) + """ + _assert_print(main, expected_output) + + if __name__ == "__main__": tvm.testing.main()