From 7e87e88189decb209f64c0f283086cc9aec4a726 Mon Sep 17 00:00:00 2001 From: llehtahw Date: Mon, 19 Apr 2021 16:11:34 +0800 Subject: [PATCH 1/3] Fix --- src/te/operation/tensorize.cc | 3 +++ .../unittest/test_te_schedule_tensorize.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index ea713220eddd..4cc7f1c3bf1b 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -327,6 +327,9 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, ana.Bind(compute_intrin_iter_space); for (size_t i = 0; i < body.size(); ++i) { + if (self->body[i].same_as(intrin_compute->body[i])) { + continue; + } PrimExpr lhs = ana.Simplify(body[i]); // run substitution because the intrin body could depend on outer loop vars. PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map)); diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index fdafdb74fc0b..dc37326eaee5 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -369,8 +369,28 @@ def intrin_func(ins, outs): assert stmt.body.body.loop_var.name == C.op.axis[0].var.name +def test_tensorize_reuse_compute(): + def get_compute_args(): + l = 2 + a = tvm.te.placeholder([l], name="a") + b = tvm.te.placeholder([l], name="b") + return a, b, tvm.te.compute([l], lambda i: a[i] + b[i]) + + a, b, c = get_compute_args() + + def get_intrin(): + def _intrin_func(ins, outs): + return tvm.tir.call_packed("fakeadd", ins[0], ins[1], outs[0]) + return tvm.te.decl_tensor_intrin(c.op, _intrin_func) + + s = tvm.te.create_schedule([c.op]) + s[c].tensorize(c.op.axis[0], get_intrin()) + tvm.lower(s, (a, b, c)) + + if __name__ == "__main__": test_tensorize_vadd() test_tensorize_matmul() test_tensorize_op() test_tensorize_tensor_compute_op() + test_tensorize_reuse_compute() From 
74ae222b39b41a5c091bdbada04de15876d00fc4 Mon Sep 17 00:00:00 2001 From: llehtahw Date: Mon, 19 Apr 2021 17:23:59 +0800 Subject: [PATCH 2/3] Fix again --- src/te/operation/tensorize.cc | 13 ++++++++++++- tests/python/unittest/test_te_schedule_tensorize.py | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 4cc7f1c3bf1b..06d50c794bbc 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -184,6 +184,8 @@ class TensorIntrinMatcher final : public StmtExprMutator { PrimExpr VisitExpr_(const ReduceNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); + if (expr.same_as(GetRef<PrimExpr>(op))) + return expr; op = expr.as<ReduceNode>(); Array<IterVar> axis; for (size_t i = 0; i < op->axis.size(); ++i) { @@ -216,6 +218,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { Array<Tensor> inputs = self->InputTensors(); ICHECK_EQ(inputs.size(), intrin->inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i].same_as(intrin->inputs[i])) { + continue; + } InputEntry e; e.tensor = intrin->inputs[i]; e.region = Array<Range>(in_region.at(inputs[i])); @@ -251,6 +256,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { for (size_t i = axis_start; i < self->axis.size(); ++i) { IterVar iv = self->axis[i]; IterVar target_iv = intrin_compute->axis[i - axis_start]; + if (iv.same_as(target_iv)) { + continue; + } Range r = out_dom.at(iv); var_remap_[iv->var.get()] = target_iv->var + r->min; axis_remap_[iv] = target_iv; @@ -270,6 +278,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) { IterVar iv = self->reduce_axis[i]; IterVar target_iv = intrin_compute->reduce_axis[i - axis_start]; + if (iv.same_as(target_iv)) { + continue; + } Range r = out_dom.at(iv); var_remap_[iv->var.get()] = target_iv->var + r->min; axis_remap_[iv] = target_iv; @@ -327,7 +338,7 @@ void VerifyTensorizeBody(const ComputeOpNode* 
self, const Stage& stage, ana.Bind(compute_intrin_iter_space); for (size_t i = 0; i < body.size(); ++i) { - if (self->body[i].same_as(intrin_compute->body[i])) { + if (body[i].same_as(intrin_compute->body[i])) { continue; } PrimExpr lhs = ana.Simplify(body[i]); diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index dc37326eaee5..7ec2066454c3 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -381,6 +381,7 @@ def get_compute_args(): def get_intrin(): def _intrin_func(ins, outs): return tvm.tir.call_packed("fakeadd", ins[0], ins[1], outs[0]) + return tvm.te.decl_tensor_intrin(c.op, _intrin_func) s = tvm.te.create_schedule([c.op]) From 29f714e45f10ec4aabc7682d06006daaf18a36ef Mon Sep 17 00:00:00 2001 From: llehtahw Date: Mon, 19 Apr 2021 18:19:27 +0800 Subject: [PATCH 3/3] Format --- src/te/operation/tensorize.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 06d50c794bbc..4b236f6e2dc2 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -184,8 +184,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { PrimExpr VisitExpr_(const ReduceNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - if (expr.same_as(GetRef<PrimExpr>(op))) - return expr; + if (expr.same_as(GetRef<PrimExpr>(op))) return expr; op = expr.as<ReduceNode>(); Array<IterVar> axis; for (size_t i = 0; i < op->axis.size(); ++i) {