Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/te/operation/op_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ Stmt Substitute(Stmt s, const std::unordered_map<IterVar, PrimExpr>& value_map)
return tir::Substitute(s, init);
}

PrimExpr Substitute(PrimExpr expr, const std::unordered_map<IterVar, PrimExpr>& value_map) {
  // Lower the IterVar-keyed substitution map to the VarNode*-keyed form that
  // tir::Substitute expects, then delegate the actual rewriting to it.
  std::unordered_map<const VarNode*, PrimExpr> var_map;
  for (auto it = value_map.begin(); it != value_map.end(); ++it) {
    var_map[it->first->var.get()] = it->second;
  }
  return tir::Substitute(expr, var_map);
}

IterVarType ForKindToIterVarType(tir::ForKind kind) {
switch (kind) {
case ForKind::kSerial:
Expand Down
10 changes: 9 additions & 1 deletion src/te/operation/op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ std::vector<Stmt> MakeIfNest(const std::vector<PrimExpr>& predicates);
*/
Stmt ReplaceTensor(Stmt stmt, const std::unordered_map<Tensor, Tensor>& replace);
/*!
* \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
 * \brief Replace the tensor reference (especially in Call nodes) in the PrimExpr by the replace map.
* \param expr The expression to be processed.
* \param replace The replacement rule.
*/
Expand All @@ -87,6 +87,14 @@ PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map<Tensor, Tensor>&
*/
Stmt Substitute(Stmt stmt, const std::unordered_map<IterVar, PrimExpr>& value_map);

/*!
* \brief Substitute the variables of primExpr by value map.
* \param expr the expression to be processed.
* \param value_map The value map.
* \return Substituted result.
*/
PrimExpr Substitute(PrimExpr expr, const std::unordered_map<IterVar, PrimExpr>& value_map);

/*!
* \brief Converts Halide ForKind to its corresponding IterVarType
* \param kind The ForKind to be converted
Expand Down
6 changes: 4 additions & 2 deletions src/te/operation/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ Array<PrimExpr> MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage
}

void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage,
const std::unordered_map<IterVar, PrimExpr>& value_map,
const std::unordered_map<IterVar, Range>& dom_map,
const std::unordered_map<IterVar, Range>& out_dom,
const std::unordered_map<Tensor, Array<Range> >& in_region,
Expand All @@ -327,7 +328,8 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage,

for (size_t i = 0; i < body.size(); ++i) {
PrimExpr lhs = ana.Simplify(body[i]);
PrimExpr rhs = ana.Simplify(intrin_compute->body[i]);
// run substitution because the intrin body could depend on outer loop vars.
PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map));
if (lhs.dtype() != rhs.dtype()) {
LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name
<< "'s declaration "
Expand All @@ -349,7 +351,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage,
ICHECK(intrin.defined());
ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop);
VerifyTensorizeLoopNest(self, stage, n, tloc);
VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
VerifyTensorizeBody(self, stage, n.main_vmap, dom_map, out_dom, in_region, intrin);
// Start bind data.
Stmt nop = Evaluate(0);
std::vector<Stmt> input_bind_nest, output_bind_nest;
Expand Down
57 changes: 48 additions & 9 deletions tests/python/unittest/test_te_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,22 @@
from tvm import te


def intrin_vadd(n):
def intrin_vadd(xo, m, n):
x = te.placeholder((n,), name="vx")
y = te.placeholder((n,), name="vy")
z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
if m % n == 0:
body = lambda i: x[i] + y[i]
else:
body = lambda i: tvm.tir.Select(
xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype)
)
z = te.compute(x.shape, body, name="z")

def intrin_func(ins, outs):
xx, yy = ins
zz = outs[0]
# special handle needed to tackle tail loop part when m % n != 0
# here is tvm.min(n, m - xo * n)
return tvm.tir.call_packed("vadd", xx, yy, zz)

buffer_params = {"offset_factor": 16}
Expand Down Expand Up @@ -84,15 +92,17 @@ def intrin_func(ins, outs):


def test_tensorize_vadd():
m = 128
x = te.placeholder((m,), name="x")
y = te.placeholder((m,), name="y")
z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
def add(m):
x = te.placeholder((m,), name="x")
y = te.placeholder((m,), name="y")
z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
return x, y, z

def check(factor):
def check(m, factor):
x, y, z = add(m)
s = te.create_schedule(z.op)
xo, xi = s[z].split(z.op.axis[0], factor=factor)
vadd = intrin_vadd(factor)
vadd = intrin_vadd(xo, m, factor)
s[z].tensorize(xi, vadd)
s = s.normalize()
dom_map = tvm.te.schedule.InferBound(s)
Expand All @@ -108,7 +118,36 @@ def check(factor):
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [x, y, z])

check(16)
def check_cache_write(m, factor):
x, y, z = add(m)
s = te.create_schedule(z.op)
_, _ = s[z].split(z.op.axis[0], factor=factor)

z_global = s.cache_write(z, "global")
xo, xi = z_global.op.axis

vadd = intrin_vadd(xo, m, factor)
s[z_global].tensorize(xi, vadd)
s = s.normalize()
dom_map = tvm.te.schedule.InferBound(s)
finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
out_dom, in_dom = finfer(s[z_global], dom_map)
# outer loop var will be rebased, so min value is the new loop var and extent is 1
assert tvm.ir.structural_equal(out_dom[xo].extent, 1)
assert isinstance(out_dom[xo].min, tvm.tir.Var)
assert xo.var.name == out_dom[xo].min.name

fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
body = fmatch(s[z_global], out_dom, in_dom, vadd)[0]
ana = tvm.arith.Analyzer()
vars = tvm.runtime.convert({xo.var: out_dom[xo].min})
vadd_body = tvm.tir.stmt_functor.substitute(vadd.op.body[0], vars)
assert tvm.ir.structural_equal(ana.simplify(body), ana.simplify(vadd_body))
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
tvm.lower(s, [x, y, z])

check(128, 16)
check_cache_write(129, 16)


def test_tensorize_matmul():
Expand Down