diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index ea713220eddd..4b236f6e2dc2 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -184,6 +184,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { PrimExpr VisitExpr_(const ReduceNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); + if (expr.same_as(GetRef(op))) return expr; op = expr.as(); Array axis; for (size_t i = 0; i < op->axis.size(); ++i) { @@ -216,6 +217,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { Array inputs = self->InputTensors(); ICHECK_EQ(inputs.size(), intrin->inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i].same_as(intrin->inputs[i])) { + continue; + } InputEntry e; e.tensor = intrin->inputs[i]; e.region = Array(in_region.at(inputs[i])); @@ -251,6 +255,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { for (size_t i = axis_start; i < self->axis.size(); ++i) { IterVar iv = self->axis[i]; IterVar target_iv = intrin_compute->axis[i - axis_start]; + if (iv.same_as(target_iv)) { + continue; + } Range r = out_dom.at(iv); var_remap_[iv->var.get()] = target_iv->var + r->min; axis_remap_[iv] = target_iv; @@ -270,6 +277,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) { IterVar iv = self->reduce_axis[i]; IterVar target_iv = intrin_compute->reduce_axis[i - axis_start]; + if (iv.same_as(target_iv)) { + continue; + } Range r = out_dom.at(iv); var_remap_[iv->var.get()] = target_iv->var + r->min; axis_remap_[iv] = target_iv; @@ -327,6 +337,9 @@ void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, ana.Bind(compute_intrin_iter_space); for (size_t i = 0; i < body.size(); ++i) { + if (body[i].same_as(intrin_compute->body[i])) { + continue; + } PrimExpr lhs = ana.Simplify(body[i]); // run substitution because the intrin body could depend on outer loop vars. PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map)); diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index fdafdb74fc0b..7ec2066454c3 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -369,8 +369,29 @@ def intrin_func(ins, outs): assert stmt.body.body.loop_var.name == C.op.axis[0].var.name +def test_tensorize_reuse_compute(): + def get_compute_args(): + l = 2 + a = tvm.te.placeholder([l], name="a") + b = tvm.te.placeholder([l], name="b") + return a, b, tvm.te.compute([l], lambda i: a[i] + b[i]) + + a, b, c = get_compute_args() + + def get_intrin(): + def _intrin_func(ins, outs): + return tvm.tir.call_packed("fakeadd", ins[0], ins[1], outs[0]) + + return tvm.te.decl_tensor_intrin(c.op, _intrin_func) + + s = tvm.te.create_schedule([c.op]) + s[c].tensorize(c.op.axis[0], get_intrin()) + tvm.lower(s, (a, b, c)) + + if __name__ == "__main__": test_tensorize_vadd() test_tensorize_matmul() test_tensorize_op() test_tensorize_tensor_compute_op() + test_tensorize_reuse_compute()