From 83e740e5c7776e95a1e3c809f22665f8f1f1036d Mon Sep 17 00:00:00 2001 From: leeexyz Date: Mon, 22 Feb 2021 21:39:49 +0800 Subject: [PATCH 1/2] [Tensorize] Support conds depend on outer loop vars inside tensorize scope --- src/te/operation/op_utils.cc | 8 +++ src/te/operation/op_utils.h | 10 +++- src/te/operation/tensorize.cc | 6 +- .../unittest/test_te_schedule_tensorize.py | 56 ++++++++++++++++--- 4 files changed, 68 insertions(+), 12 deletions(-) diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index 32ffccbbec1f..b3897e142545 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -243,6 +243,14 @@ Stmt Substitute(Stmt s, const std::unordered_map& value_map) return tir::Substitute(s, init); } +PrimExpr Substitute(PrimExpr s, const std::unordered_map& value_map) { + std::unordered_map init; + for (const auto& kv : value_map) { + init[kv.first->var.get()] = kv.second; + } + return tir::Substitute(s, init); +} + IterVarType ForKindToIterVarType(tir::ForKind kind) { switch (kind) { case ForKind::kSerial: diff --git a/src/te/operation/op_utils.h b/src/te/operation/op_utils.h index e6bf2caae6e0..02f4a860a01d 100644 --- a/src/te/operation/op_utils.h +++ b/src/te/operation/op_utils.h @@ -73,7 +73,7 @@ std::vector MakeIfNest(const std::vector& predicates); */ Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace); /*! - * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. + * \brief Replace the tensor reference (especially in Call's) in primExpr by the replace map. * \param expr The expression to be processed. * \param replace The replacement rule. */ @@ -87,6 +87,14 @@ PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& */ Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); +/*! + * \brief Substitute the variables of primExpr by value map. + * \param expr the expression to be processed. + * \param value_map The value map. + * \return Substituted result. + */ +PrimExpr Substitute(PrimExpr expr, const std::unordered_map& value_map); + /*! * \brief Converts Halide ForKind to its corresponding IterVarType * \param kind The ForKind to be converted diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index bfd1ec579818..ea713220eddd 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -311,6 +311,7 @@ Array MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage } void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& value_map, const std::unordered_map& dom_map, const std::unordered_map& out_dom, const std::unordered_map >& in_region, @@ -327,7 +328,8 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < body.size(); ++i) { PrimExpr lhs = ana.Simplify(body[i]); - PrimExpr rhs = ana.Simplify(intrin_compute->body[i]); + // run substitution because the intrin body could depend on outer loop vars. + PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map)); if (lhs.dtype() != rhs.dtype()) { LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name << "'s declaration " @@ -349,7 +351,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, ICHECK(intrin.defined()); ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); - VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin); + VerifyTensorizeBody(self, stage, n.main_vmap, dom_map, out_dom, in_region, intrin); // Start bind data. Stmt nop = Evaluate(0); std::vector input_bind_nest, output_bind_nest; diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index 83a5d30bb90d..0960b342fffc 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -18,14 +18,20 @@ from tvm import te -def intrin_vadd(n): +def intrin_vadd(xo, m, n): x = te.placeholder((n,), name="vx") y = te.placeholder((n,), name="vy") - z = te.compute(x.shape, lambda i: x[i] + y[i], name="z") + if m % n == 0: + body = lambda i: x[i] + y[i] + else: + body = lambda i: tvm.tir.Select(xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype)) + z = te.compute(x.shape, body, name="z") def intrin_func(ins, outs): xx, yy = ins zz = outs[0] + # special handle needed to tackle tail loop part when m % n != 0 + # here is tvm.min(n, m - xo * n) return tvm.tir.call_packed("vadd", xx, yy, zz) buffer_params = {"offset_factor": 16} @@ -84,15 +90,18 @@ def intrin_func(ins, outs): def test_tensorize_vadd(): - m = 128 - x = te.placeholder((m,), name="x") - y = te.placeholder((m,), name="y") - z = te.compute(x.shape, lambda i: x[i] + y[i], name="z") - def check(factor): + def add(m): + x = te.placeholder((m,), name="x") + y = te.placeholder((m,), name="y") + z = te.compute(x.shape, lambda i: x[i] + y[i], name="z") + return x, y, z + + def check(m, factor): + x, y, z = add(m) s = te.create_schedule(z.op) xo, xi = s[z].split(z.op.axis[0], factor=factor) - vadd = intrin_vadd(factor) + vadd = intrin_vadd(xo, m, factor) s[z].tensorize(xi, vadd) s = s.normalize() dom_map = tvm.te.schedule.InferBound(s) @@ -108,7 +117,36 @@ def check(factor): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) - check(16) + def check_cache_write(m, factor): + x, y, z = add(m) + s = te.create_schedule(z.op) + _, _ = s[z].split(z.op.axis[0], factor=factor) + + z_global = s.cache_write(z, "global") + xo, xi = z_global.op.axis + + vadd = intrin_vadd(xo, m, factor) + s[z_global].tensorize(xi, vadd) + s = s.normalize() + dom_map = tvm.te.schedule.InferBound(s) + finfer = tvm.get_global_func("test.op.InferTensorizeRegion") + out_dom, in_dom = finfer(s[z_global], dom_map) + # outer loop var will be rebased, so min value is the new loop var and extent is 1 + assert tvm.ir.structural_equal(out_dom[xo].extent, 1) + assert isinstance(out_dom[xo].min, tvm.tir.Var) + assert xo.var.name == out_dom[xo].min.name + + fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") + body = fmatch(s[z_global], out_dom, in_dom, vadd)[0] + ana = tvm.arith.Analyzer() + vars = tvm.runtime.convert({xo.var: out_dom[xo].min}) + vadd_body = tvm.tir.stmt_functor.substitute(vadd.op.body[0], vars) + assert tvm.ir.structural_equal(ana.simplify(body), ana.simplify(vadd_body)) + stmt = tvm.te.schedule.ScheduleOps(s, dom_map) + tvm.lower(s, [x, y, z]) + + check(128, 16) + check_cache_write(129, 16) def test_tensorize_matmul(): From 47a8b9480b2a1b494f7c4676a52a651e90419945 Mon Sep 17 00:00:00 2001 From: leeexyz Date: Thu, 25 Feb 2021 00:05:48 +0800 Subject: [PATCH 2/2] Reformat --- tests/python/unittest/test_te_schedule_tensorize.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index 0960b342fffc..fdafdb74fc0b 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -24,7 +24,9 @@ def intrin_vadd(xo, m, n): if m % n == 0: body = lambda i: x[i] + y[i] else: - body = lambda i: tvm.tir.Select(xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype)) + body = lambda i: tvm.tir.Select( + xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype) + ) z = te.compute(x.shape, body, name="z") def intrin_func(ins, outs): @@ -90,7 +92,6 @@ def intrin_func(ins, outs): def test_tensorize_vadd(): - def add(m): x = te.placeholder((m,), name="x") y = te.placeholder((m,), name="y")