Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,17 +235,14 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
// SVE, we can make some assumptions about the value of vscale and iterate over a
// space of pre-defined values to attempt to prove the expression.
if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) {
Target curr_target = tvm::Target::Current();
if (curr_target.defined() && curr_target->features.defined() &&
(curr_target->features.find("has_sve") != curr_target->features.end()) &&
curr_target->GetFeature<Bool>("has_sve").value_or(Bool(false)).operator bool()) {
if (TargetHasSVE()) {
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
}
LOG(WARNING)
<< "The expression contains scalable values. An attempt to prove by substituting "
"with known values of vscale was not performed. This proof currently only supports "
"AArch64 SVE targets, but the target was "
<< curr_target;
<< Target::Current();
}
return false;
}
Expand Down
9 changes: 9 additions & 0 deletions src/arith/scalable_expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,14 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
return can_prove_expr;
}

bool TargetHasSVE() {
Target current_target = Target::Current();
bool has_sve{false};
if (current_target.defined()) {
has_sve = current_target->GetFeature<Bool>("has_sve").value_or(Bool(false));
}
return has_sve;
}

} // namespace arith
} // namespace tvm
6 changes: 6 additions & 0 deletions src/arith/scalable_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
const std::vector<unsigned int>& vscale_values);

/*!
 * \brief Check whether the current compilation target supports SVE.
 *
 * The target is obtained from Target::Current(); its "has_sve" feature is queried.
 * \return Whether SVE is supported. False when no target is defined.
 */
bool TargetHasSVE();

} // namespace arith
} // namespace tvm

Expand Down
13 changes: 11 additions & 2 deletions src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#include <unordered_map>
#include <vector>

#include "../../src/arith/scalable_expression.h"
#include "../../tir/analysis/check_contains.h"

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -727,6 +730,14 @@ class LoopVectorizer : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kVectorized) {
auto* extent_as_int = op->extent.as<IntImmNode>();

if (!extent_as_int || extent_as_int->value < 1) {
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
ICHECK(is_scalable_expr && arith::TargetHasSVE())
<< "Failed to vectorize loop with extent " << op->extent << " for target "
<< Target::Current();
}
ICHECK(is_zero(op->min));
return Vectorizer(op->loop_var, op->extent)(op->body);
} else {
Expand All @@ -735,8 +746,6 @@ class LoopVectorizer : public StmtMutator {
}
};

/*! \brief Apply a LoopVectorizer to the given statement and return the result. */
Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }

class VectorizeSkipper : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode* op) final {
Expand Down
152 changes: 97 additions & 55 deletions tests/python/tir-transform/test_tir_transform_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
import pytest


@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
def test_vectorize_loop(extent):
simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu")
sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve")


@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_loop(extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -37,8 +41,9 @@ class After:
def main(A: T.Buffer((16,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent)

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


def test_vectorize_vector():
Expand Down Expand Up @@ -70,8 +75,9 @@ def main(A: T.Buffer((25,), "float32")):
A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4)

error_msg = f"Creating scalable vectors from existing vectors is not supported."
with pytest.raises(tvm.error.InternalError, match=error_msg):
tvm.tir.transform.VectorizeLoop()(Module)
with tvm.target.Target(sve_target):
with pytest.raises(tvm.error.InternalError, match=error_msg):
tvm.tir.transform.VectorizeLoop()(Module)


def test_vectorize_vector_scalable_error2():
Expand Down Expand Up @@ -99,7 +105,8 @@ def main(A: T.Buffer((25,), "float32")):

error_msg = f"Vectorizing over existing scalable vectors is not supported."
with pytest.raises(tvm.error.InternalError, match=error_msg):
tvm.tir.transform.VectorizeLoop()(Module)
with tvm.target.Target(sve_target):
tvm.tir.transform.VectorizeLoop()(Module)


def test_vectorize_vector_scalable_error4():
Expand All @@ -114,11 +121,12 @@ def main(A: T.Buffer((25,), "float32")):

error_msg = f"Creating scalable vectors from existing vectors is not supported."
with pytest.raises(tvm.error.InternalError, match=error_msg):
tvm.tir.transform.VectorizeLoop()(Module)
with tvm.target.Target(sve_target):
tvm.tir.transform.VectorizeLoop()(Module)


@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
def test_vectorize_with_if(extent):
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_with_if(extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -143,8 +151,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
if i_s < n:
A[i_s] = T.float32(2)

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


def test_vectorize_with_if_cond_int64():
Expand All @@ -157,8 +166,8 @@ def test_vectorize_with_if_cond_int64():
f = tvm.build(s, [A, B], "llvm")


@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
def test_vectorize_let(extent):
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_let(extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -174,12 +183,13 @@ def main(A: T.Buffer((25,), "float32")):
v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent)

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4))
def test_vectorize_with_le_cond(extent):
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
def test_vectorize_with_le_cond(extent, target):
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
Expand All @@ -189,14 +199,16 @@ def test_vectorize_with_le_cond(extent):
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body

# Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)
with tvm.target.Target(target):
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body

# Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)


@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4))
def test_vectorize_with_ge_cond(extent):
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
def test_vectorize_with_ge_cond(extent, target):
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
Expand All @@ -206,14 +218,16 @@ def test_vectorize_with_ge_cond(extent):
stmt = ib.get()

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body

# Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)
with tvm.target.Target(target):
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body

# Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)

@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
def test_vectorize_if_then_else_scalarize(extent):

@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_if_then_else_scalarize(extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -228,12 +242,13 @@ def main(A: T.Buffer((25,), "float32")):
for i_s in range(extent):
A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s])

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
def test_vectorize_if_then_else_vector(extent):
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_if_then_else_vector(extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -251,8 +266,9 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32):
i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent)
)

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


def test_vectorize_while_fail():
Expand Down Expand Up @@ -311,9 +327,10 @@ def test_vectorize_dtype_mismatch():


@pytest.mark.parametrize(
"extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")]
"extent, vec_str, target",
[(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)],
)
def test_vectorize_with_reinterpret(extent, vec_str):
def test_vectorize_with_reinterpret(extent, vec_str, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -327,11 +344,12 @@ class After:
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)])

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
@pytest.mark.parametrize(
"op",
(
Expand All @@ -352,7 +370,7 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
T.NE,
),
)
def test_vectorize_binary(op, extent):
def test_vectorize_binary(op, extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -366,13 +384,14 @@ class After:
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)])

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
@pytest.mark.parametrize("op", (T.And, T.Or))
def test_vectorize_logical(op, extent):
def test_vectorize_logical(op, extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -386,12 +405,13 @@ class After:
def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)])

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
def test_vectorize_select(extent):
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_select(extent, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -409,12 +429,16 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
B[T.Ramp(0, 1, extent)],
)

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")])
def test_vectorize_cast(extent, vec_str):
@pytest.mark.parametrize(
"extent, vec_str, target",
[(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)],
)
def test_vectorize_cast(extent, vec_str, target):
@I.ir_module
class Before:
@T.prim_func
Expand All @@ -428,8 +452,9 @@ class After:
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)])

mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)


def test_illegal_extent():
Expand All @@ -441,10 +466,27 @@ def main(A: T.Buffer((25,), "int32")):
for j in T.vectorized(n):
A[j] = 3

error_msg = f"Invalid expression for scalable lanes n"
error_msg = f"Failed to vectorize loop with extent n for target \\(nullptr\\)"
with pytest.raises(tvm.error.InternalError, match=error_msg):
tvm.tir.transform.VectorizeLoop()(Mod)


def test_illegal_vscale_in_non_sve_compilation():
    """Vectorizing a loop with a vscale-based extent under ``simple_target`` must raise.

    The pass is expected to fail with ``tvm.error.InternalError`` whose message
    names both the loop extent and the current target.
    """

    @I.ir_module
    class Mod:
        @T.prim_func
        def main(A: T.Buffer((16,), "float32")):
            for j in T.vectorized(0, 4 * T.vscale()):
                A[j] = 13

    # `match=` is applied as a regular expression, so the parentheses and '*'
    # in the expected extent text are escaped.
    msg = (
        f"Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target "
        f"llvm -keys=cpu -mtriple=x86_64-linux-gnu"
    )
    # Make simple_target the current target while the pass runs.
    with tvm.target.Target(simple_target):
        with pytest.raises(tvm.error.InternalError, match=msg):
            tvm.tir.transform.VectorizeLoop()(Mod)


if __name__ == "__main__":
tvm.testing.main()