From 3f4eff15332f77daae17c079cd92e655a5590646 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Tue, 16 Apr 2024 16:08:13 +0100 Subject: [PATCH 1/3] [SVE] Check for SVE target in VectorizeLoop Check that we are compiling for an SVE enabled target when the extent of a loop marked for vectorizing has a vscale dependent extent. --- src/driver/driver_api.cc | 4 + src/tir/transforms/vectorize_loop.cc | 25 +++- .../test_tir_transform_vectorize.py | 109 +++++++++++++----- 3 files changed, 106 insertions(+), 32 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7ea5032fa0cc..e88137989969 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -161,6 +161,7 @@ Array CreatePassList(bool disable_loop_partition) { .value(); bool instrument_lwp = pass_ctx->GetConfig("tir.instrument_lwp", Bool(false)).value(); + Target current_target = Target::Current(); Array user_lower_phase0 = Array(); Array user_lower_phase1 = Array(); @@ -196,6 +197,9 @@ Array CreatePassList(bool disable_loop_partition) { Array pass_list = user_lower_phase0; // PHASE 1 + if (current_target.defined()) { + pass_list.push_back(tir::transform::BindTarget(current_target)); + } pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::TextureFlatten()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index a9cc4975801a..541ec80bbccf 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -34,6 +34,9 @@ #include #include +#include "../../src/arith/scalable_expression.h" +#include "../../tir/analysis/check_contains.h" + namespace tvm { namespace tir { @@ -725,17 +728,33 @@ class Vectorizer : public StmtMutator, public ExprFunctorattrs.GetAttr(tvm::attr::kTarget); + if (target.defined()) { + target_ = Downcast(target); + has_sve_ = target_->GetFeature("has_sve").value_or(Bool(false)); + } + } + Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { + auto* extent_as_int = op->extent.as(); + if (!extent_as_int || extent_as_int->value < 1) { + bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); + ICHECK(is_scalable_expr && has_sve_) + << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; + } ICHECK(is_zero(op->min)); return Vectorizer(op->loop_var, op->extent)(op->body); } else { return StmtMutator::VisitStmt_(op); } } -}; -Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); } + private: + bool has_sve_{false}; + Target target_{}; +}; class VectorizeSkipper : public StmtMutator { public: @@ -759,7 +778,7 @@ Pass VectorizeLoop(bool enable_vectorize) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); if (enable_vectorize) { - n->body = LoopVectorizer()(std::move(n->body)); + n->body = LoopVectorizer(f)(std::move(n->body)); } else { n->body = VectorizeSkipper()(std::move(n->body)); } diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index dbca006b19cb..4dce7def8604 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -22,12 +22,17 @@ import pytest -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_loop(extent): +simple_target = tvm.target.Target("llvm 
-mtriple=x86_64-linux-gnu") +sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve") + + +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_loop(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(0, extent): A[j] = 1 @@ -35,6 +40,7 @@ def main(A: T.Buffer((16,), "float32")): class After: @T.prim_func def main(A: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent) mod = tvm.tir.transform.VectorizeLoop()(Before) @@ -66,6 +72,7 @@ def test_vectorize_vector_scalable_error(): class Module: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": sve_target}) for j in T.vectorized(T.vscale() * 4): A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4) @@ -99,7 +106,8 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Vectorizing over existing scalable vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + tvm.tir.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error4(): @@ -107,6 +115,7 @@ def test_vectorize_vector_scalable_error4(): class Module: @T.prim_func(private=True) def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": sve_target}) for j in T.vectorized(T.vscale() * 4): A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast( T.float32(1), T.vscale() * 4 @@ -114,15 +123,17 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Creating scalable vectors from existing vectors is not supported." 
with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + tvm.tir.transform.VectorizeLoop()(Module) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_with_if(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_with_if(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + T.func_attr({"target": target}) for i in T.vectorized(extent): if x < n: A[i] = A[i] + T.float32(1) @@ -134,6 +145,7 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): + T.func_attr({"target": target}) if x < n: A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( T.float32(1), extent @@ -157,12 +169,13 @@ def test_vectorize_with_if_cond_int64(): f = tvm.build(s, [A, B], "llvm") -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_let(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_let(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for i in T.vectorized(extent): v = A[i] + T.float32(1) A[i] = v + T.float32(2) @@ -171,6 +184,7 @@ def main(A: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent) A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent) @@ -178,8 +192,8 @@ def main(A: T.Buffer((25,), "float32")): tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) -def test_vectorize_with_le_cond(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) +def test_vectorize_with_le_cond(extent, target): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") @@ -188,15 +202,16 @@ def test_vectorize_with_le_cond(extent): A[i] = A[i] + 1 stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + func = tvm.tir.PrimFunc([A, n], stmt).with_attr("target", target) + mod = tvm.IRModule.from_expr(func) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body # Check that the loop was't vectorised assert isinstance(stmt, tvm.tir.For) -@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4)) -def test_vectorize_with_ge_cond(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) +def test_vectorize_with_ge_cond(extent, target): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") @@ -205,19 +220,21 @@ def test_vectorize_with_ge_cond(extent): A[i] = A[i] + 1 stmt = ib.get() - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + func = tvm.tir.PrimFunc([A, n], stmt).with_attr("target", target) + mod = tvm.IRModule.from_expr(func) stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body # Check that the loop wasn't vectorised assert isinstance(stmt, tvm.tir.For) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_if_then_else_scalarize(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, 
sve_target)]) +def test_vectorize_if_then_else_scalarize(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for i in T.vectorized(extent): A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i]) @@ -225,6 +242,7 @@ def main(A: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for i_s in range(extent): A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s]) @@ -232,12 +250,13 @@ def main(A: T.Buffer((25,), "float32")): tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_if_then_else_vector(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_if_then_else_vector(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32): + T.func_attr({"target": target}) for i in range(n): for j in T.vectorized(extent): A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0) @@ -246,6 +265,7 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32): + T.func_attr({"target": target}) for i in range(n): A[T.Ramp(i * extent, 1, extent)] = T.if_then_else( i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent) @@ -311,13 +331,15 @@ def test_vectorize_dtype_mismatch(): @pytest.mark.parametrize( - "extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")] + "extent, vec_str, target", + [(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)], ) -def test_vectorize_with_reinterpret(extent, vec_str): +def test_vectorize_with_reinterpret(extent, vec_str, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) for i in T.vectorized(0, extent): B[i] = T.reinterpret("float32", A[i]) @@ -325,13 +347,14 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): class After: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): + T.func_attr({"target": target}) B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize( "op", ( @@ -352,11 +375,12 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): T.NE, ), ) -def test_vectorize_binary(op, extent): +def test_vectorize_binary(op, extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = op(T.float32(3), B[j]) @@ -364,19 +388,21 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) +@pytest.mark.parametrize("extent, target", 
[(4, simple_target), (T.vscale() * 4, sve_target)]) @pytest.mark.parametrize("op", (T.And, T.Or)) -def test_vectorize_logical(op, extent): +def test_vectorize_logical(op, extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = op(T.bool(1), B[j]) @@ -384,18 +410,20 @@ def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): class After: @T.prim_func def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent", (4, T.vscale() * 4)) -def test_vectorize_select(extent): +@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) +def test_vectorize_select(extent, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = T.Select(T.bool(True), A[j], B[j]) @@ -403,6 +431,7 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Select( T.Broadcast(T.bool(True), extent), A[T.Ramp(0, 1, extent)], @@ -413,12 +442,16 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): tvm.ir.assert_structural_equal(mod, After) -@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")]) -def test_vectorize_cast(extent, vec_str): +@pytest.mark.parametrize( + "extent, vec_str, target", + [(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)], +) +def test_vectorize_cast(extent, vec_str, target): @I.ir_module class Before: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = T.Cast("int32", B[j]) @@ -426,6 +459,7 @@ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): + T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)]) mod = tvm.tir.transform.VectorizeLoop()(Before) @@ -441,10 +475,27 @@ def main(A: T.Buffer((25,), "int32")): for j in T.vectorized(n): A[j] = 3 - error_msg = f"Invalid expression for scalable lanes n" + error_msg = f"Failed to vectorize loop with extent n for target \\(nullptr\\)" with pytest.raises(tvm.error.InternalError, match=error_msg): tvm.tir.transform.VectorizeLoop()(Mod) +def test_illegal_vscale_in_non_sve_compilation(): + @I.ir_module + class Mod: + @T.prim_func + def main(A: T.Buffer((16,), "float32")): + T.func_attr({"target": simple_target}) + for j in T.vectorized(0, 4 * T.vscale()): + A[j] = 13 + + msg = ( + f"Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target " + f"llvm -keys=cpu -mtriple=x86_64-linux-gnu" + ) + with pytest.raises(tvm.error.InternalError, match=msg): + tvm.tir.transform.VectorizeLoop()(Mod) + + if __name__ == "__main__": tvm.testing.main() From b1ab8797039edf1c0a1b671c9f4a8de42cad1a32 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Thu, 18 Apr 2024 10:28:02 +0100 Subject: 
[PATCH 2/3] Use Target::Current() Use Target::Current() in LoopVectorizer to check for SVE Change-Id: I15363bad540d6752d6c2098c93efce25c107309b --- src/driver/driver_api.cc | 4 - src/tir/transforms/vectorize_loop.cc | 23 ++-- .../test_tir_transform_vectorize.py | 108 ++++++++---------- 3 files changed, 58 insertions(+), 77 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e88137989969..7ea5032fa0cc 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -161,7 +161,6 @@ Array CreatePassList(bool disable_loop_partition) { .value(); bool instrument_lwp = pass_ctx->GetConfig("tir.instrument_lwp", Bool(false)).value(); - Target current_target = Target::Current(); Array user_lower_phase0 = Array(); Array user_lower_phase1 = Array(); @@ -197,9 +196,6 @@ Array CreatePassList(bool disable_loop_partition) { Array pass_list = user_lower_phase0; // PHASE 1 - if (current_target.defined()) { - pass_list.push_back(tir::transform::BindTarget(current_target)); - } pass_list.push_back(tir::transform::InjectPrefetch()); pass_list.push_back(tir::transform::TextureFlatten()); pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 541ec80bbccf..502a8d413774 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -728,21 +728,18 @@ class Vectorizer : public StmtMutator, public ExprFunctorattrs.GetAttr(tvm::attr::kTarget); - if (target.defined()) { - target_ = Downcast(target); - has_sve_ = target_->GetFeature("has_sve").value_or(Bool(false)); - } - } - Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { auto* extent_as_int = op->extent.as(); if (!extent_as_int || extent_as_int->value < 1) { + Target current_target = Target::Current(); + bool has_sve{false}; + if (current_target.defined()) { + has_sve = current_target->GetFeature("has_sve").value_or(Bool(false)); + } bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); - ICHECK(is_scalable_expr && has_sve_) - << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; + ICHECK(is_scalable_expr && has_sve) << "Failed to vectorize loop with extent " << op->extent + << " for target " << current_target; } ICHECK(is_zero(op->min)); return Vectorizer(op->loop_var, op->extent)(op->body); @@ -750,10 +747,6 @@ class LoopVectorizer : public StmtMutator { return StmtMutator::VisitStmt_(op); } } - - private: - bool has_sve_{false}; - Target target_{}; }; class VectorizeSkipper : public StmtMutator { @@ -778,7 +771,7 @@ Pass VectorizeLoop(bool enable_vectorize) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); if (enable_vectorize) { - n->body = LoopVectorizer(f)(std::move(n->body)); + n->body = LoopVectorizer()(std::move(n->body)); } else { n->body = VectorizeSkipper()(std::move(n->body)); } diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 4dce7def8604..cd15c20a8aa1 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -32,7 +32,6 @@ def test_vectorize_loop(extent, target): class Before: @T.prim_func def main(A: T.Buffer((16,), "float32")): - T.func_attr({"target": target}) for j in T.vectorized(0, extent): A[j] = 1 @@ -40,11 +39,11 @@ def main(A: 
T.Buffer((16,), "float32")): class After: @T.prim_func def main(A: T.Buffer((16,), "float32")): - T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_vector(): @@ -72,13 +71,13 @@ def test_vectorize_vector_scalable_error(): class Module: @T.prim_func def main(A: T.Buffer((25,), "float32")): - T.func_attr({"target": sve_target}) for j in T.vectorized(T.vscale() * 4): A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4) error_msg = f"Creating scalable vectors from existing vectors is not supported." - with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + with tvm.target.Target(sve_target): + with pytest.raises(tvm.error.InternalError, match=error_msg): + tvm.tir.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error2(): @@ -106,7 +105,7 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Vectorizing over existing scalable vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): - with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + with tvm.target.Target(sve_target): tvm.tir.transform.VectorizeLoop()(Module) @@ -123,7 +122,7 @@ def main(A: T.Buffer((25,), "float32")): error_msg = f"Creating scalable vectors from existing vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): - with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"): + with tvm.target.Target(sve_target): tvm.tir.transform.VectorizeLoop()(Module) @@ -133,7 +132,6 @@ def test_vectorize_with_if(extent, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): - T.func_attr({"target": target}) for i in T.vectorized(extent): if x < n: A[i] = A[i] + T.float32(1) @@ -145,7 +143,6 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): - T.func_attr({"target": target}) if x < n: A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast( T.float32(1), extent @@ -155,8 +152,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32): if i_s < n: A[i_s] = T.float32(2) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_with_if_cond_int64(): @@ -175,7 +173,6 @@ def test_vectorize_let(extent, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) for i in T.vectorized(extent): v = A[i] + T.float32(1) A[i] = v + T.float32(2) @@ -184,12 +181,12 @@ def main(A: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent) A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) @pytest.mark.parametrize("extent, target", [(4, simple_target), 
(tvm.tir.vscale() * 4, sve_target)]) @@ -202,12 +199,13 @@ def test_vectorize_with_le_cond(extent, target): A[i] = A[i] + 1 stmt = ib.get() - func = tvm.tir.PrimFunc([A, n], stmt).with_attr("target", target) - mod = tvm.IRModule.from_expr(func) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - # Check that the loop was't vectorised - assert isinstance(stmt, tvm.tir.For) + with tvm.target.Target(target): + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + + # Check that the loop was't vectorised + assert isinstance(stmt, tvm.tir.For) @pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)]) @@ -220,12 +218,13 @@ def test_vectorize_with_ge_cond(extent, target): A[i] = A[i] + 1 stmt = ib.get() - func = tvm.tir.PrimFunc([A, n], stmt).with_attr("target", target) - mod = tvm.IRModule.from_expr(func) - stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) - # Check that the loop wasn't vectorised - assert isinstance(stmt, tvm.tir.For) + with tvm.target.Target(target): + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + + # Check that the loop wasn't vectorised + assert isinstance(stmt, tvm.tir.For) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @@ -234,7 +233,6 @@ def test_vectorize_if_then_else_scalarize(extent, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) for i in T.vectorized(extent): A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i]) @@ -242,12 +240,12 @@ def main(A: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) for i_s in range(extent): A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @@ -256,7 +254,6 @@ def test_vectorize_if_then_else_vector(extent, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32): - T.func_attr({"target": target}) for i in range(n): for j in T.vectorized(extent): A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0) @@ -265,14 +262,14 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), n: T.int32): - T.func_attr({"target": target}) for i in range(n): A[T.Ramp(i * extent, 1, extent)] = T.if_then_else( i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent) ) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_vectorize_while_fail(): @@ -339,7 +336,6 @@ def test_vectorize_with_reinterpret(extent, vec_str, target): class Before: @T.prim_func def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): - T.func_attr({"target": target}) for i in T.vectorized(0, extent): B[i] = T.reinterpret("float32", A[i]) @@ -347,11 +343,11 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): class After: @T.prim_func def main(A: 
T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): - T.func_attr({"target": target}) B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @@ -380,7 +376,6 @@ def test_vectorize_binary(op, extent, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = op(T.float32(3), B[j]) @@ -388,11 +383,11 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @@ -402,7 +397,6 @@ def test_vectorize_logical(op, extent, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): - T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = op(T.bool(1), B[j]) @@ -410,11 +404,11 @@ def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): class After: @T.prim_func def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): - T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @@ -423,7 +417,6 @@ def test_vectorize_select(extent, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = T.Select(T.bool(True), A[j], B[j]) @@ -431,15 +424,15 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Select( T.Broadcast(T.bool(True), extent), A[T.Ramp(0, 1, extent)], B[T.Ramp(0, 1, extent)], ) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) @pytest.mark.parametrize( @@ -451,7 +444,6 @@ def test_vectorize_cast(extent, vec_str, target): class Before: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): - T.func_attr({"target": target}) for j in T.vectorized(extent): A[j] = T.Cast("int32", B[j]) @@ -459,11 +451,11 @@ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): class After: @T.prim_func def main(A: T.Buffer((25,), "int32"), B: 
T.Buffer((25,), "float32")): - T.func_attr({"target": target}) A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)]) - mod = tvm.tir.transform.VectorizeLoop()(Before) - tvm.ir.assert_structural_equal(mod, After) + with tvm.target.Target(target): + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) def test_illegal_extent(): @@ -485,7 +477,6 @@ def test_illegal_vscale_in_non_sve_compilation(): class Mod: @T.prim_func def main(A: T.Buffer((16,), "float32")): - T.func_attr({"target": simple_target}) for j in T.vectorized(0, 4 * T.vscale()): A[j] = 13 @@ -493,8 +484,9 @@ def main(A: T.Buffer((16,), "float32")): f"Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target " f"llvm -keys=cpu -mtriple=x86_64-linux-gnu" ) - with pytest.raises(tvm.error.InternalError, match=msg): - tvm.tir.transform.VectorizeLoop()(Mod) + with tvm.target.Target(simple_target): + with pytest.raises(tvm.error.InternalError, match=msg): + tvm.tir.transform.VectorizeLoop()(Mod) if __name__ == "__main__": From acb0e7496bd78bbc8dfa28bfca2ccbfa44c7e82a Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Fri, 19 Apr 2024 16:22:27 +0100 Subject: [PATCH 3/3] Respond to review Change-Id: I0569534397a2d0db9587db6424b1674846a76079 --- src/arith/analyzer.cc | 7 ++----- src/arith/scalable_expression.cc | 9 +++++++++ src/arith/scalable_expression.h | 6 ++++++ src/tir/transforms/vectorize_loop.cc | 11 ++++------- .../tir-transform/test_tir_transform_vectorize.py | 1 - 5 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index b40670e4aa09..db39e4c0a42a 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -235,17 +235,14 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // SVE, we can make some assumptions about the value of vscale and iterate over a // space of pre-defined values to attempt to prove the expression. if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) { - Target curr_target = tvm::Target::Current(); - if (curr_target.defined() && curr_target->features.defined() && - (curr_target->features.find("has_sve") != curr_target->features.end()) && - curr_target->GetFeature("has_sve").value_or(Bool(false)).operator bool()) { + if (TargetHasSVE()) { return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues); } LOG(WARNING) << "The expression contains scalable values. An attempt to prove by substituting " "with known values of vscale was not performed. 
This proof currently only supports " "AArch64 SVE targets, but the target was " - << curr_target; + << Target::Current(); } return false; } diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index 38ec576ac297..0c5aea4e7da7 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -88,5 +88,14 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr return can_prove_expr; } +bool TargetHasSVE() { + Target current_target = Target::Current(); + bool has_sve{false}; + if (current_target.defined()) { + has_sve = current_target->GetFeature("has_sve").value_or(Bool(false)); + } + return has_sve; +} + } // namespace arith } // namespace tvm diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h index e014f808f514..091783a59f8c 100644 --- a/src/arith/scalable_expression.h +++ b/src/arith/scalable_expression.h @@ -71,6 +71,12 @@ std::optional ExtractVscaleFactor(const PrimExpr& lanes); bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr, const std::vector& vscale_values); +/*! + * \brief Check whether the compilation target supports SVE + * \return Whether SVE is supported + */ +bool TargetHasSVE(); + } // namespace arith } // namespace tvm diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 502a8d413774..3f5c07025044 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -731,15 +731,12 @@ class LoopVectorizer : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (op->kind == ForKind::kVectorized) { auto* extent_as_int = op->extent.as(); + if (!extent_as_int || extent_as_int->value < 1) { - Target current_target = Target::Current(); - bool has_sve{false}; - if (current_target.defined()) { - has_sve = current_target->GetFeature("has_sve").value_or(Bool(false)); - } bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall); - ICHECK(is_scalable_expr && has_sve) << "Failed to vectorize loop with extent " << op->extent - << " for target " << current_target; + ICHECK(is_scalable_expr && arith::TargetHasSVE()) + << "Failed to vectorize loop with extent " << op->extent << " for target " + << Target::Current(); } ICHECK(is_zero(op->min)); return Vectorizer(op->loop_var, op->extent)(op->body); diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index cd15c20a8aa1..de5453eb5c44 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -114,7 +114,6 @@ def test_vectorize_vector_scalable_error4(): class Module: @T.prim_func(private=True) def main(A: T.Buffer((25,), "float32")): - T.func_attr({"target": sve_target}) for j in T.vectorized(T.vscale() * 4): A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast( T.float32(1), T.vscale() * 4
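
Usage note (not part of the patch series itself): the behaviour introduced above is exercised from Python by running the VectorizeLoop pass inside a target scope, as the tests in this series do. The sketch below is a hedged summary, not an authoritative example; the target string and module layout are copied from the tests above, and exact attributes may need adjusting for other setups.

import tvm
from tvm.script import ir as I
from tvm.script import tir as T

# Target string taken from the sve_target definition in the tests above.
sve_target = tvm.target.Target(
    "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve"
)


@I.ir_module
class Mod:
    @T.prim_func
    def main(A: T.Buffer((16,), "float32")):
        # A scalable (vscale-dependent) extent is only accepted when the
        # target currently in scope reports the has_sve feature.
        for j in T.vectorized(T.vscale() * 4):
            A[j] = T.float32(1)


# With an SVE target in scope the loop is vectorized into scalable vectors.
with sve_target:
    vectorized = tvm.tir.transform.VectorizeLoop()(Mod)

# Without an SVE target in scope (or with no target at all), the same call
# trips the new check and raises an InternalError of the form
# "Failed to vectorize loop with extent T.vscale() * 4 for target ...".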