diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index b40670e4aa09..db39e4c0a42a 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -235,17 +235,14 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // SVE, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) { - Target curr_target = tvm::Target::Current(); - if (curr_target.defined() && curr_target->features.defined() && - (curr_target->features.find("has_sve") != curr_target->features.end()) && - curr_target->GetFeature("has_sve").value_or(Bool(false)).operator bool()) { + if (TargetHasSVE()) { return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); } LOG(WARNING) << "The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. This proof currently only supports " "AArch64 SVE targets, but the target was " - << curr_target; + << Target::Current(); } return false; } diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 38ec576ac297..0c5aea4e7da7 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -88,5 +88,14 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } +bool TargetHasSVE() { + Target current_target = Target::Current(); + bool has_sve{false}; + if (current_target.defined()) { + has_sve = current_target->GetFeature("has_sve").value_or(Bool(false)); + } + return has_sve; +} + } // namespace arith } // namespace tvm diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index e014f808f514..091783a59f8c 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -71,6 +71,12 @@ std::optional ExtractVscaleFactor(const PrimExpr& lanes); bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr, const std::vector& vscale_values); +/*! + * \brief Check whether the compilation target supports SVE + * \return Whether SVE is supported + */ +bool TargetHasSVE(); + } // namespace arith } // namespace tvm diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index a9cc4975801a..3f5c07025044 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -34,6 +34,9 @@ #include #include +#include "../../src/arith/scalable_expression.h" +#include "../../tir/analysis/check_contains.h" + namespace tvm { namespace tir { @@ -727,6 +730,14 @@ class LoopVectorizer : public StmtMutator { public: Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { + auto* extent_as_int = op->extent.as(); + + if (!extent_as_int || extent_as_int->value < 1) { + bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); + ICHECK(is_scalable_expr && arith::TargetHasSVE()) + << "Failed to vectorize loop with extent " << op->extent << " for target " + << Target::Current(); + } ICHECK(is_zero(op->min)); return Vectorizer(op->loop_var, op->extent)(op->body); } else { @@ -735,8 +746,6 @@ class LoopVectorizer : public StmtMutator { } }; -Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); } - class VectorizeSkipper : public StmtMutator { public: Stmt VisitStmt_(const ForNode* op) final { diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index dbca006b19cb..de5453eb5c44 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -22,8 +22,12 @@ import pytest -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_loop(extent): +simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu") +sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve") + + +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_loop(extent, target): @I.ir_module class Before: @T.prim_func @@ -37,8 +41,9 @@ class After: def main(A: T.Buffer((16,), "float32")): A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_vector(): @@ -70,8 +75,9 @@ def main(A: T.Buffer((25,), "float32")): A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4) error_msg = f"Creating scalable vectors from existing vectors is not supported." - with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target(sve_target): + with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error2(): @@ -99,7 +105,8 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Vectorizing over existing scalable vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target(sve_target): + tvm.tir.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error4(): @@ -114,11 +121,12 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Creating scalable vectors from existing vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target(sve_target): + tvm.tir.transform.VectorizeLoop()(Module) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_with_if(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_with_if(extent, target): @I.ir_module class Before: @T.prim_func @@ -143,8 +151,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): if i_s < n: A[i_s] = T.float32(2) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_with_if_cond_int64(): @@ -157,8 +166,8 @@ def test_vectorize_with_if_cond_int64(): f = tvm.build(s, [A, B], "llvm") -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_let(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_let(extent, target): @I.ir_module class Before: @T.prim_func @@ -174,12 +183,13 @@ def main(A: T.Buffer((25,), "float32")): v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent) A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) -def test_vectorize_with_le_cond(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) +def test_vectorize_with_le_cond(extent, target): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") @@ -189,14 +199,16 @@ def test_vectorize_with_le_cond(extent): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body - # Check that the loop was't vectorised - assert isinstance(stmt, tvm.tir.For) + with tvm.target.Target(target): + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + + # Check that the loop was't vectorised + assert isinstance(stmt, tvm.tir.For) -@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) -def test_vectorize_with_ge_cond(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) +def test_vectorize_with_ge_cond(extent, target): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") @@ -206,14 +218,16 @@ def test_vectorize_with_ge_cond(extent): stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body - # Check that the loop wasn't vectorised - assert isinstance(stmt, tvm.tir.For) + with tvm.target.Target(target): + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + # Check that the loop wasn't vectorised + assert isinstance(stmt, tvm.tir.For) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_if_then_else_scalarize(extent): + +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_if_then_else_scalarize(extent, target): @I.ir_module class Before: @T.prim_func @@ -228,12 +242,13 @@ def main(A: T.Buffer((25,), "float32")): for i_s in range(extent): A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_if_then_else_vector(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_if_then_else_vector(extent, target): @I.ir_module class Before: @T.prim_func @@ -251,8 +266,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32): i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent) ) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_while_fail(): @@ -311,9 +327,10 @@ def test_vectorize_dtype_mismatch(): @pytest.mark.parametrize( - "extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")] + "extent, vec_str, target", + [(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)], ) -def test_vectorize_with_reinterpret(extent, vec_str): +def test_vectorize_with_reinterpret(extent, vec_str, target): @I.ir_module class Before: @T.prim_func @@ -327,11 +344,12 @@ class After: def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize( "op", ( @@ -352,7 +370,7 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): T.NE, ), ) -def test_vectorize_binary(op, extent): +def test_vectorize_binary(op, extent, target): @I.ir_module class Before: @T.prim_func @@ -366,13 +384,14 @@ class After: def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("op", (T.And, T.Or)) -def test_vectorize_logical(op, extent): +def test_vectorize_logical(op, extent, target): @I.ir_module class Before: @T.prim_func @@ -386,12 +405,13 @@ class After: def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_select(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_select(extent, target): @I.ir_module class Before: @T.prim_func @@ -409,12 +429,16 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): B[T.Ramp(0, 1, extent)], ) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")]) -def test_vectorize_cast(extent, vec_str): +@pytest.mark.parametrize( + "extent, vec_str, target", + [(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)], +) +def test_vectorize_cast(extent, vec_str, target): @I.ir_module class Before: @T.prim_func @@ -428,8 +452,9 @@ class After: def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_illegal_extent(): @@ -441,10 +466,27 @@ def main(A: T.Buffer((25,), "int32")): for j in T.vectorized(n): A[j] = 3 - error_msg = f"Invalid expression for scalable lanes n" + error_msg = f"Failed to vectorize loop with extent n for target \\(nullptr\\)" with pytest.raises(tvm.error.InternalError, match=error_msg): tvm.tir.transform.VectorizeLoop()(Mod) +def test_illegal_vscale_in_non_sve_compilation(): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((16,), "float32")): + for j in T.vectorized(0, 4 * T.vscale()): + A[j] = 13 + + msg = ( + f"Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target " + f"llvm -keys=cpu -mtriple=x86_64-linux-gnu" + ) + with tvm.target.Target(simple_target): + with pytest.raises(tvm.error.InternalError, match=msg): + tvm.tir.transform.VectorizeLoop()(Mod) + + if __name__ == "__main__": tvm.testing.main()