diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index f7a0f2a3b61c..4d821c2c4236 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -546,6 +546,35 @@ class StmtExprMutator : } }; +/*! + * \brief recursively visit the ir in post DFS order node, and transform it + * + * \param node The ir to be transformed. + * \param preorder The function called in before recursive mutation + * If preorder returns None, then the transform will proceed to recursive call. + * If preorder returns a not None Stmt/Expr, the transformer will simply return it and + * won't do further recursion. + * \param postorder The function called after recursive mutation. + * The recursive mutation result is passed to postorder for further mutation. + * \param only_enable List of StringImm. + * If it is empty, all IRNode will call preorder/postorder + * If it is not empty, preorder/postorder will only be called + * when the IRNode's type key is in the list. + */ +TVM_DLL Stmt IRTransform(Stmt node, + const runtime::PackedFunc& preorder, + const runtime::PackedFunc& postorder, + const Array& only_enable = {}); + +/*! + * \brief recursively visit the ir in post DFS order node, apply fvisit + * Each node is guaranteed to be visited only once. + * \param node The ir to be visited. + * \param fvisit The visitor function to be applied. + */ +TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); + + } // namespace ir } // namespace tvm #endif // TVM_IR_FUNCTOR_EXT_H_ diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 702cea3ce8fd..14769586a959 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -122,27 +122,6 @@ class TVM_DLL IRMutator { virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const Shuffle* op, const Expr& e); }; - - -/*! - * \brief recursively visit the ir in post DFS order node, and transform it - * - * \param node The ir to be transformed. 
- * \param preorder The function called in before recursive mutation - * If preorder returns None, then the transform will proceed to recursive call. - * If preorder returns a not None Stmt/Expr, the transformer will simply return it and - * won't do further recursion. - * \param postorder The function called after recursive mutation. - * The recursive mutation result is passed to postorder for further mutation. - * \param only_enable List of StringImm. - * If it is empty, all IRNode will call preorder/postorder - * If it is not empty, preorder/postorder will only be called - * when the IRNode's type key is in the list. - */ -Stmt IRTransform(Stmt node, - const runtime::PackedFunc& preorder, - const runtime::PackedFunc& postorder, - const Array& only_enable = {}); } // namespace ir } // namespace tvm #endif // TVM_IR_MUTATOR_H_ diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index cffcdcbdf5b8..e6bd3c6f344d 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -145,15 +145,6 @@ class TVM_DLL IRVisitor { virtual void Visit_(const FloatImm* op); virtual void Visit_(const StringImm* op); }; - -/*! - * \brief recursively visit the ir in post DFS order node, apply fvisit - * Each node is guaranteed to be visited only once. - * \param node The ir to be visited. - * \param fvisit The visitor function to be applied. 
- */ -TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); - } // namespace ir } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 339b25a51894..f1d97f49783d 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -25,8 +25,7 @@ #include #include #include -#include -#include +#include #include namespace tvm { diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 0b84be291f71..6f98017e6a69 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include @@ -38,17 +38,17 @@ using namespace ir; // a visitor to find the path to the target variable // from a expression. -class VariablePathFinder: public IRVisitor { +class VariablePathFinder: public ExprVisitor { public: explicit VariablePathFinder(Expr target) : target_(target) {} - void Visit(const ObjectRef& node) final { + void VisitExpr(const Expr& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); if (!found_) path_.push_back(node.get()); if (node.same_as(target_)) found_ = true; - IRVisitor::Visit(node); + ExprVisitor::VisitExpr(node); if (!found_) path_.pop_back(); } @@ -64,14 +64,14 @@ class VariablePathFinder: public IRVisitor { // return empty vector to represent failure std::vector GetPath(Expr target, Expr expr) { VariablePathFinder v(target); - v.Visit(expr); + v(expr); return v.path_; } enum CompareOp {kGreater, kLess, kEqual}; // a visitor to deduce the bound of a variable from a expression -class BoundDeducer: public IRVisitor { +class BoundDeducer: public ExprVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; @@ -82,39 +82,39 @@ class BoundDeducer: public IRVisitor { void Deduce(); - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (!success_) return; if (e.get() == path_[iter_++]) { - IRVisitor::Visit(e); + 
ExprVisitor::VisitExpr(e); } else { success_ = false; return; } } - void Visit_(const LT* op) final { + void VisitExpr_(const LT* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const LE* op) final { + void VisitExpr_(const LE* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const GT* op) final { + void VisitExpr_(const GT* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const GE* op) final { + void VisitExpr_(const GE* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const Add* op) final { + void VisitExpr_(const Add* op) final { bool left = op->a.get() == path_[iter_]; result_ -= left ? op->b : op->a; - Visit(left ? op->a : op->b); + this->VisitExpr(left ? op->a : op->b); } - void Visit_(const Sub* op) final { + void VisitExpr_(const Sub* op) final { bool left = op->a.get() == path_[iter_]; if (left) { result_ += op->b; @@ -123,10 +123,10 @@ class BoundDeducer: public IRVisitor { result_ = - result_; comp_op = ReverseOp(comp_op); } - Visit(left ? op->a : op->b); + this->VisitExpr(left ? op->a : op->b); } - void Visit_(const Mul* op) final { + void VisitExpr_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; Expr target_var = left ? op->a : op->b; @@ -171,7 +171,7 @@ class BoundDeducer: public IRVisitor { // ( x <= -3/-2 --> x <= 1) } } - Visit(left ? op->a : op->b); + this->VisitExpr(left ? 
op->a : op->b); } Expr result_; @@ -194,17 +194,17 @@ class BoundDeducer: public IRVisitor { Analyzer analyzer_; }; -class BoundDeduceInputChecker: public IRVisitor { +class BoundDeduceInputChecker: public ExprVisitor { public: bool Check(BoundDeducer* deducer) { deducer_ = deducer; - Visit(deducer_->expr_); + this->VisitExpr(deducer_->expr_); return target_count == 1; } - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (e.same_as(deducer_->target_)) ++target_count; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } private: @@ -305,7 +305,7 @@ void BoundDeducer::Deduce() { } expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); - Visit(expr_); + this->VisitExpr(expr_); } void BoundDeducer::Relax() { diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 6a19a7aeb3f2..d05ee2dd9a30 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include "const_fold.h" #include "pattern_match.h" #include "rewrite_simplify.h" @@ -435,30 +434,30 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { Expr CanonicalSimplify(Expr expr) { - expr = Mutate(expr); + expr = operator()(expr); return expr; } // override the original mutate function. - Expr Mutate(Expr expr) final { - expr = IRMutator::Mutate(expr); + Expr VisitExpr(const Expr& input_expr) final { + auto expr = Rewriter::VisitExpr(input_expr); return Normalize(expr); } // Normal mutation without normalization. 
Expr CanonicalMutate(Expr expr) { - return IRMutator::Mutate(expr); + return Rewriter::VisitExpr(expr); } - using Rewriter::Mutate_; - Expr Mutate_(const Add* op, const Expr& self) final; - Expr Mutate_(const Sub* op, const Expr& self) final; - Expr Mutate_(const Mul* op, const Expr& self) final; - Expr Mutate_(const Div* op, const Expr& self) final; - Expr Mutate_(const Mod* op, const Expr& self) final; - Expr Mutate_(const FloorDiv* op, const Expr& self) final; - Expr Mutate_(const FloorMod* op, const Expr& self) final; - Expr Mutate_(const Reduce* op, const Expr& self) final; + using Rewriter::VisitExpr_; + Expr VisitExpr_(const Add* op) final; + Expr VisitExpr_(const Sub* op) final; + Expr VisitExpr_(const Mul* op) final; + Expr VisitExpr_(const Div* op) final; + Expr VisitExpr_(const Mod* op) final; + Expr VisitExpr_(const FloorDiv* op) final; + Expr VisitExpr_(const FloorMod* op) final; + Expr VisitExpr_(const Reduce* op) final; private: /*! @@ -567,9 +566,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { }; Expr CanonicalSimplifier::Impl:: -Mutate_(const Add* op, const Expr& self) { +VisitExpr_(const Add* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -593,9 +592,9 @@ Mutate_(const Add* op, const Expr& self) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Sub* op, const Expr& self) { +VisitExpr_(const Sub* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -620,9 +619,9 @@ Mutate_(const Sub* op, const Expr& self) { Expr CanonicalSimplifier::Impl:: -Mutate_(const Mul* op, const Expr& self) { +VisitExpr_(const Mul* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -652,7 +651,7 @@ 
Mutate_(const Mul* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return Mul::make(a, b); } @@ -727,9 +726,9 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Div* op, const Expr& self) { +VisitExpr_(const Div* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } Expr a = this->CanonicalMutate(op->a); @@ -781,16 +780,16 @@ Mutate_(const Div* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return Div::make(a, b); } } Expr CanonicalSimplifier::Impl:: -Mutate_(const FloorDiv* op, const Expr& self) { +VisitExpr_(const FloorDiv* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } Expr a = this->CanonicalMutate(op->a); Expr b = this->CanonicalMutate(op->b); @@ -837,7 +836,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return FloorDiv::make(a, b); } @@ -866,7 +865,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { // Do a recursive call to simplify the mod with the new factor. 
if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { - auto updated = ToSplitExpr(Mutate(ModImpl( + auto updated = ToSplitExpr(this->VisitExpr(ModImpl( lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); // re-apply the lower_factor if (lhs->lower_factor != 1) { @@ -894,9 +893,9 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Mod* op, const Expr& self) { +VisitExpr_(const Mod* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -957,16 +956,16 @@ Mutate_(const Mod* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return Mod::make(a, b); } } Expr CanonicalSimplifier::Impl:: -Mutate_(const FloorMod* op, const Expr& self) { +VisitExpr_(const FloorMod* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -1017,7 +1016,7 @@ Mutate_(const FloorMod* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return FloorMod::make(a, b); } @@ -1029,7 +1028,7 @@ SimplifyReduceCombiner(const Reduce* op) { // First simplify the results Array simplified_result; for (const auto& res : op->combiner->result) { - Expr new_res = Mutate(res); + Expr new_res = this->VisitExpr(res); simplified_result.push_back(new_res); } @@ -1078,7 +1077,7 @@ SimplifyReduceCombiner(const Reduce* op) { if (used[i]) { // We simplify the result and identity, but not the source new_result.push_back(simplified_result[i]); - new_identity.push_back(Mutate(op->combiner->identity_element[i])); + new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i])); 
new_lhs.push_back(op->combiner->lhs[i]); new_rhs.push_back(op->combiner->rhs[i]); new_source.push_back(op->source[i]); @@ -1095,9 +1094,9 @@ SimplifyReduceCombiner(const Reduce* op) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Reduce* op, const Expr& self) { +VisitExpr_(const Reduce* op) { // Recursively call simplification when necessary. - Expr ret = RewriteSimplifier::Impl::Mutate_(op, self); + Expr ret = RewriteSimplifier::Impl::VisitExpr_(op); op = ret.as(); // already been simplified by const reduction axis removal if (op == nullptr) return ret; @@ -1106,7 +1105,7 @@ Mutate_(const Reduce* op, const Expr& self) { // assumption we would have to perform a single iteration of the loop, i.e. use // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. - return Mutate( + return this->VisitExpr( Select::make(op->condition, op->source[op->value_index], op->combiner->identity_element[op->value_index])); diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 93bf708a113f..8b4ea2fa8133 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -25,7 +25,6 @@ #define TVM_ARITHMETIC_CONST_FOLD_H_ #include -#include #include #include #include diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index c4ee40f12da8..b8ec974b436c 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc index 947f0050c6cb..bdd5daa6265c 100644 --- a/src/arithmetic/domain_touched.cc +++ b/src/arithmetic/domain_touched.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include @@ -36,13 +36,13 @@ namespace arith { using namespace ir; // Find Read region of the tensor in the stmt. -class FuncTouchedDomain final : public IRVisitor { +class FuncTouchedDomain final : public StmtExprVisitor { public: FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides) : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {} Domain Find(const Stmt& stmt) { - this->Visit(stmt); + operator()(stmt); Domain ret; Range none; for (size_t i = 0; i < bounds_.size(); ++i) { @@ -51,49 +51,49 @@ class FuncTouchedDomain final : public IRVisitor { return ret; } - void Visit_(const For *op) final { + void VisitStmt_(const For *op) final { const Variable* var = op->loop_var.get(); dom_map_[var] = IntSet::range( Range::make_by_min_extent(op->min, op->extent)); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } - void Visit_(const LetStmt* op) final { + void VisitStmt_(const LetStmt* op) final { dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); dom_map_.erase(op->var.get()); } /* TODO: Thread extent unitest not generated.*/ - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); const Variable* var = thread_axis->var.get(); dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } else { - IRVisitor::Visit_(op); 
+ StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { if (consider_calls_ && tensor_->op.same_as(op->func) && tensor_->value_index == op->value_index) { Touch(op->args); } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } - void Visit_(const Provide* op) final { + void VisitStmt_(const Provide* op) final { if (consider_provides_ && tensor_->op.same_as(op->func) && tensor_->value_index == op->value_index) { Touch(op->args); } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } private: diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index 0d4b8f26b18b..bfce2c26fbe3 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -30,41 +30,44 @@ namespace arith { using namespace ir; Stmt IRMutatorWithAnalyzer:: -Mutate_(const For* op, const Stmt& s) { +VisitStmt_(const For* op) { analyzer_->Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); - return IRMutator::Mutate_(op, s); + Range::make_by_min_extent(op->min, op->extent)); + return StmtExprMutator::VisitStmt_(op); } Stmt IRMutatorWithAnalyzer:: -Mutate_(const LetStmt* op, const Stmt& s) { - Expr value = this->Mutate(op->value); +VisitStmt_(const LetStmt* op) { + Expr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } // We keep the let-binding here // as sub-class may or maynot choose to replace it. 
- Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return s; + return GetRef(op); } else { - return LetStmt::make(op->var, value, body); + auto n = this->CopyOnWrite(op); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); } } Stmt IRMutatorWithAnalyzer:: -Mutate_(const IfThenElse* op, const Stmt& s) { - Expr condition = this->Mutate(op->condition); +VisitStmt_(const IfThenElse* op) { + Expr condition = this->VisitExpr(op->condition); Stmt then_case, else_case; { With ctx(analyzer_, condition); - then_case = this->Mutate(op->then_case); + then_case = this->VisitStmt(op->then_case); } if (op->else_case.defined()) { With ctx(analyzer_, analyzer_->rewrite_simplify(Not::make(condition))); - else_case = this->Mutate(op->else_case); + else_case = this->VisitStmt(op->else_case); } if (is_one(condition)) return then_case; if (is_zero(condition)) { @@ -77,57 +80,65 @@ Mutate_(const IfThenElse* op, const Stmt& s) { if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return s; + return GetRef(op); } else { - return IfThenElse::make(condition, then_case, else_case); + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->then_case = std::move(then_case); + n->else_case = std::move(else_case); + return Stmt(n); } } Stmt IRMutatorWithAnalyzer:: -Mutate_(const AttrStmt* op, const Stmt& s) { +VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } Stmt IRMutatorWithAnalyzer:: -Mutate_(const AssertStmt* 
op, const Stmt& s) { - Expr condition = this->Mutate(op->condition); - Expr message = this->Mutate(op->message); +VisitStmt_(const AssertStmt* op) { + Expr condition = this->VisitExpr(op->condition); + Expr message = this->VisitExpr(op->message); With ctx(analyzer_, condition); - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { - return s; + return GetRef(op); } else { - return AssertStmt::make(condition, message, body); + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->message = std::move(message); + n->body = std::move(body); + return Stmt(n); } } Expr IRMutatorWithAnalyzer:: -Mutate_(const Call* op, const Expr& self) { +VisitExpr_(const Call* op) { // add condition context to if_then_else if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { - Expr cond = Mutate(op->args[0]); + Expr cond = this->VisitExpr(op->args[0]); Expr true_value, false_value; { With constraint(analyzer_, cond); - true_value = Mutate(op->args[1]); + true_value = this->VisitExpr(op->args[1]); } { With constraint(analyzer_, analyzer_->rewrite_simplify(Not::make(cond))); - false_value = Mutate(op->args[2]); + false_value = this->VisitExpr(op->args[2]); } if (is_zero(cond)) { return false_value; @@ -138,45 +149,45 @@ Mutate_(const Call* op, const Expr& self) { if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { - return self; + return GetRef(op); } else { return Call::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } } - return IRMutator::Mutate_(op, self); + return StmtExprMutator::VisitExpr_(op); } Expr IRMutatorWithAnalyzer:: -Mutate_(const Let* op, const Expr& self) { - Expr value = this->Mutate(op->value); +VisitExpr_(const Let* op) { + Expr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } // We keep the 
let-binding here // as sub-class may or maynot choose to replace it. - Expr body = this->Mutate(op->body); + Expr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return self; + return GetRef(op); } else { return Let::make(op->var, value, body); } } Expr IRMutatorWithAnalyzer:: -Mutate_(const Select* op, const Expr& self) { - Expr cond = Mutate(op->condition); +VisitExpr_(const Select* op) { + Expr cond = this->VisitExpr(op->condition); Expr true_value, false_value; { With constraint(analyzer_, cond); - true_value = Mutate(op->true_value); + true_value = VisitExpr(op->true_value); } { With constraint(analyzer_, analyzer_->rewrite_simplify(Not::make(cond))); - false_value = Mutate(op->false_value); + false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { return false_value; @@ -188,20 +199,20 @@ Mutate_(const Select* op, const Expr& self) { if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return self; + return GetRef(op); } else { return Select::make(cond, true_value, false_value); } } Expr IRMutatorWithAnalyzer:: -Mutate_(const Reduce* op, const Expr& self) { +VisitExpr_(const Reduce* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_->Bind(iv->var, iv->dom); } // Recursively call simplification when necessary. 
- return IRMutator::Mutate_(op, self); + return StmtExprMutator::VisitExpr_(op); } } // namespace arith diff --git a/src/arithmetic/ir_mutator_with_analyzer.h b/src/arithmetic/ir_mutator_with_analyzer.h index bf4118e9c698..9e3a86bb5280 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.h +++ b/src/arithmetic/ir_mutator_with_analyzer.h @@ -24,9 +24,9 @@ #ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_ -#include +#include #include - +#include namespace tvm { namespace arith { @@ -40,23 +40,24 @@ namespace arith { * * \sa src/arithmetic/ir_mutator_with_analyzer.cc */ -class IRMutatorWithAnalyzer : public ir::IRMutator { +class IRMutatorWithAnalyzer : public ir::StmtExprMutator { public: explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} - using IRMutator::Mutate_; + using StmtExprMutator::VisitStmt_; + using StmtExprMutator::VisitExpr_; // override functions that need to populate the context information. - Stmt Mutate_(const ir::For* op, const Stmt& self) override; - Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override; - Stmt Mutate_(const ir::IfThenElse* op, const Stmt& self) override; - Stmt Mutate_(const ir::AttrStmt* op, const Stmt& self) override; - Stmt Mutate_(const ir::AssertStmt* op, const Stmt& self) override; - Expr Mutate_(const ir::Let* op, const Expr& self) override; - Expr Mutate_(const ir::Select* op, const Expr& self) override; - Expr Mutate_(const ir::Call* op, const Expr& self) override; - Expr Mutate_(const ir::Reduce* op, const Expr& self) override; + Stmt VisitStmt_(const ir::For* op) override; + Stmt VisitStmt_(const ir::LetStmt* op) override; + Stmt VisitStmt_(const ir::IfThenElse* op) override; + Stmt VisitStmt_(const ir::AttrStmt* op) override; + Stmt VisitStmt_(const ir::AssertStmt* op) override; + Expr VisitExpr_(const ir::Let* op) override; + Expr VisitExpr_(const ir::Select* op) override; + Expr VisitExpr_(const ir::Call* op) override; + Expr 
VisitExpr_(const ir::Reduce* op) override; protected: /*! \brief internal analyzer field. */ diff --git a/src/arithmetic/ir_visitor_with_analyzer.h b/src/arithmetic/ir_visitor_with_analyzer.h index 918f2e89501f..b8750df64afe 100644 --- a/src/arithmetic/ir_visitor_with_analyzer.h +++ b/src/arithmetic/ir_visitor_with_analyzer.h @@ -27,43 +27,43 @@ #include #include -#include +#include namespace tvm { namespace ir { -class IRVisitorWithAnalyzer final : public IRVisitor { +class IRVisitorWithAnalyzer final : public StmtExprVisitor { public: Expr Simplify(const Expr& expr) { return analyzer_.Simplify(expr); } - void Visit_(const For* op) { + void VisitStmt_(const For* op) { analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); - return IRVisitor::Visit_(op); + return StmtExprVisitor::VisitStmt_(op); } - void Visit_(const AttrStmt* op) { + void VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const Reduce* op) { + void VisitExpr_(const Reduce* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_.Bind(iv->var, iv->dom); } // Recursively call simplification when necessary. - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } protected: diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 235306cc7bf8..f883bf145f59 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -24,7 +24,6 @@ // Acknowledgement: Most rewrite-rules are from Halide. 
#include #include -#include #include #include "const_fold.h" #include "pattern_match.h" @@ -69,7 +68,7 @@ using namespace ir; // try to prove x equals val RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: TryCompare(const Expr& x, int64_t val) { - Expr diff = Mutate(x); + Expr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { return kEQ; @@ -117,8 +116,8 @@ Update(const Var& var, const Expr& info, bool override) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Add* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Add* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -232,8 +231,8 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const Expr& const } Expr RewriteSimplifier::Impl:: -Mutate_(const Sub* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Sub* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -431,8 +430,8 @@ Mutate_(const Sub* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Mul* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Mul* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -470,8 +469,8 @@ Mutate_(const Mul* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Div* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Div* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as
<Div>(); Expr const_res = TryConstFold<Div>
(op->a, op->b); if (const_res.defined()) return const_res; @@ -692,8 +691,8 @@ Mutate_(const Div* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Mod* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Mod* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -782,8 +781,8 @@ Mutate_(const Mod* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const FloorDiv* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const FloorDiv* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -926,8 +925,8 @@ Mutate_(const FloorDiv* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const FloorMod* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const FloorMod* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -996,8 +995,8 @@ Mutate_(const FloorMod* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Min* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Min* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1181,8 +1180,8 @@ Mutate_(const Min* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Max* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Max* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1354,8 +1353,8 @@ Mutate_(const Max* op, const Expr& self) { } Expr 
RewriteSimplifier::Impl:: -Mutate_(const EQ* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const EQ* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1388,28 +1387,28 @@ Mutate_(const EQ* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const NE* op, const Expr& self) { - return Mutate(Not::make(op->a == op->b)); +VisitExpr_(const NE* op) { + return this->VisitExpr(Not::make(op->a == op->b)); } Expr RewriteSimplifier::Impl:: -Mutate_(const LE* op, const Expr& self) { - return Mutate(Not::make(op->b < op->a)); +VisitExpr_(const LE* op) { + return this->VisitExpr(Not::make(op->b < op->a)); } Expr RewriteSimplifier::Impl:: -Mutate_(const GT* op, const Expr& self) { - return Mutate(op->b < op->a); +VisitExpr_(const GT* op) { + return this->VisitExpr(op->b < op->a); } Expr RewriteSimplifier::Impl:: -Mutate_(const GE* op, const Expr& self) { - return Mutate(Not::make(op->a < op->b)); +VisitExpr_(const GE* op) { + return this->VisitExpr(Not::make(op->a < op->b)); } Expr RewriteSimplifier::Impl:: -Mutate_(const LT* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const LT* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1564,8 +1563,8 @@ Mutate_(const LT* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Not* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Not* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; @@ -1589,8 +1588,8 @@ Mutate_(const Not* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const And* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); 
+VisitExpr_(const And* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1638,8 +1637,8 @@ Mutate_(const And* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Or* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Or* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1688,8 +1687,8 @@ Mutate_(const Or* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Select* op, const Expr& self) { - Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self); +VisitExpr_(const Select* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); UnsafeExprDetector unsafe; bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); @@ -131,7 +130,7 @@ class UnsafeSelectRewriter : public IRMutator { }; Stmt RewriteUnsafeSelect(Stmt stmt) { - return UnsafeSelectRewriter().Mutate(stmt); + return UnsafeSelectRewriter()(std::move(stmt)); } } // namespace ir diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 1159e568f519..e9ed89304167 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -22,25 +22,24 @@ * \brief Implementation of simple passes */ #include -#include -#include +#include #include namespace tvm { namespace ir { -class IRSideEffect : public IRVisitor { +class IRSideEffect : public ExprVisitor { public: - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (has_side_effect_) return; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { if (!op->is_pure()) { has_side_effect_ = true; return; } else { - IRVisitor::Visit_(op); + ExprVisitor::VisitExpr_(op); } } @@ -49,23 +48,23 @@ class 
IRSideEffect : public IRVisitor { bool HasSideEffect(const Expr& e) { IRSideEffect v; - v.Visit(e); + v(e); return v.has_side_effect_; } -class IRSubstitue : public IRMutator { +class IRSubstitue : public StmtExprMutator { public: explicit IRSubstitue( const std::unordered_map& smap) : smap_(smap) { } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = smap_.find(op); if (it != smap_.end()) { return it->second; } else { - return e; + return GetRef(op); } } @@ -76,13 +75,13 @@ class IRSubstitue : public IRMutator { Stmt Substitute(Stmt stmt, const std::unordered_map& value_map) { if (value_map.size() == 0) return stmt; - return IRSubstitue(value_map).Mutate(stmt); + return IRSubstitue(value_map)(std::move(stmt)); } Expr Substitute(Expr expr, const std::unordered_map& value_map) { if (value_map.size() == 0) return expr; - return IRSubstitue(value_map).Mutate(expr); + return IRSubstitue(value_map)(std::move(expr)); } Stmt Substitute(Stmt stmt, const Map& value_map) { @@ -101,20 +100,20 @@ Expr Substitute(Expr expr, const Map& value_map) { return Substitute(expr, vmap); } -class VarTouchVisitor : public IRVisitor { +class VarTouchVisitor : public ExprVisitor { public: - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (use_var_) return; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } - void Visit_(const Variable* op) final { + void VisitExpr_(const Variable* op) final { Handle(op); } - void Visit_(const Load* op) final { + void VisitExpr_(const Load* op) final { Handle(op->buffer_var.get()); - IRVisitor::Visit_(op); + ExprVisitor::VisitExpr_(op); } virtual void Handle(const Variable* var) = 0; @@ -149,14 +148,14 @@ class ExprUseVSetVisitor : public VarTouchVisitor { bool ExprUseVar(const Expr& e, const Var& v) { ExprUseVarVisitor visitor(v.get()); - visitor.Visit(e); + visitor(e); return visitor.use_var_; } bool ExprUseVar(const Expr& e, const std::unordered_set& vset) 
{ ExprUseVSetVisitor visitor(vset); - visitor.Visit(e); + visitor(e); return visitor.use_var_; } diff --git a/src/pass/skip_assert.cc b/src/pass/skip_assert.cc index 817416d9fd2c..47a158f1d760 100644 --- a/src/pass/skip_assert.cc +++ b/src/pass/skip_assert.cc @@ -19,22 +19,22 @@ #include #include -#include +#include namespace tvm { namespace ir { -class AssertSkipper : public IRMutator { +class AssertSkipper : public StmtMutator { public: - Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const AssertStmt* op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); return op->body; } }; Stmt SkipAssert(Stmt stmt) { - return AssertSkipper().Mutate(stmt); + return AssertSkipper()(std::move(stmt)); } LoweredFunc SkipAssert(LoweredFunc f) { diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index f045c271456c..2a7c75e04eb5 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -32,9 +32,9 @@ namespace tvm { namespace ir { // use/def analysis, also delete unreferenced lets -class IRUseDefAnalysis : public IRMutator { +class IRUseDefAnalysis : public StmtExprMutator { public: - Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); @@ -48,75 +48,77 @@ class IRUseDefAnalysis : public IRMutator { Expr value = op->value; if (visit_thread_extent_) { - value = this->Mutate(value); + value = this->VisitExpr(value); + } + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); } - Stmt body = this->Mutate(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) return s; return AttrStmt::make(op->node, op->attr_key, value, body); } 
else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const LetStmt *op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { this->HandleDef(op->var.get()); - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { - Expr value = this->Mutate(op->value); + Expr value = this->VisitExpr(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { - return s; + return GetRef(op); } else { return LetStmt::make(op->var, value, body); } } } - Stmt Mutate_(const For *op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { this->HandleDef(op->loop_var.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Allocate *op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { this->HandleDef(op->buffer_var.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Store *op, const Stmt& s) final { + Stmt VisitStmt_(const Store* op) final { this->HandleUse(op->buffer_var); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Expr Mutate_(const Let *op, const Expr& e) final { + Expr VisitExpr_(const Let* op) final { this->HandleDef(op->var.get()); - Expr body = this->Mutate(op->body); + Expr body = this->VisitExpr(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { - Expr value = this->Mutate(op->value); + Expr value = this->VisitExpr(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { - return e; + return GetRef(op); } else { return Let::make(op->var, value, body); } } } - Expr Mutate_(const Variable *op, const Expr& e) final { - this->HandleUse(e); - return IRMutator::Mutate_(op, e); + Expr VisitExpr_(const 
Variable* op) final { + this->HandleUse(GetRef(op)); + return StmtExprMutator::VisitExpr_(op); } - Expr Mutate_(const Load *op, const Expr& e) final { + Expr VisitExpr_(const Load* op) final { this->HandleUse(op->buffer_var); - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } void HandleDef(const Variable* v) { @@ -154,20 +156,20 @@ class IRUseDefAnalysis : public IRMutator { std::unordered_map def_count_; }; -class HostDeviceSplitter : public IRMutator { +class HostDeviceSplitter : public StmtMutator { public: - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } - Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { - return SplitDeviceFunc(s); + return SplitDeviceFunc(GetRef(op)); } - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } Array Split(LoweredFunc f) { @@ -178,7 +180,7 @@ class HostDeviceSplitter : public IRMutator { name_ = f->name; ObjectPtr n = make_object(*f.operator->()); - n->body = this->Mutate(f->body); + n->body = operator()(f->body); n->func_type = kHostFunc; Array ret{LoweredFunc(n)}; for (LoweredFunc x : device_funcs_) { @@ -195,7 +197,7 @@ class HostDeviceSplitter : public IRMutator { // isolate the device function. 
IRUseDefAnalysis m; m.visit_thread_extent_ = false; - n->body = m.Mutate(body); + n->body = m(std::move(body)); n->name = os.str(); n->func_type = kDeviceFunc; n->thread_axis = m.thread_axis_; @@ -243,7 +245,7 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { for (Var arg : args) { m.use_count_[arg.get()] = 0; } - m.Mutate(stmt); + m(stmt); return m.undefined_; } diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 37db29c58079..94f045d94561 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -24,8 +24,7 @@ * \file ssa.cc */ #include -#include -#include +#include #include #include #include @@ -34,29 +33,33 @@ namespace tvm { namespace ir { namespace { -class IRVerifySSA final : public IRVisitor { +class IRVerifySSA final : public StmtExprVisitor { public: bool is_ssa{true}; - void Visit(const ObjectRef& n) final { + void VisitExpr(const Expr& n) final { if (!is_ssa) return; - IRVisitor::Visit(n); + StmtExprVisitor::VisitExpr(n); } - void Visit_(const Let* op) final { + void VisitStmt(const Stmt& n) final { + if (!is_ssa) return; + StmtExprVisitor::VisitStmt(n); + } + void VisitExpr_(const Let* op) final { MarkDef(op->var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } - void Visit_(const LetStmt* op) final { + void VisitStmt_(const LetStmt* op) final { MarkDef(op->var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const For* op) final { + void VisitStmt_(const For* op) final { MarkDef(op->loop_var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Allocate* op) final { + void VisitStmt_(const Allocate* op) final { MarkDef(op->buffer_var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } private: @@ -70,31 +73,32 @@ class IRVerifySSA final : public IRVisitor { std::unordered_map defined_; }; -class IRConvertSSA final : public IRMutator { + +class IRConvertSSA final : public StmtExprMutator { public: - Expr Mutate_(const Variable* op, 
const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { if (scope_.count(op)) { return scope_[op].back(); } else { - return e; + return GetRef(op); } } - Expr Mutate_(const Let* op, const Expr& e) final { + Expr VisitExpr_(const Let* op) final { const VarExpr& v = op->var; if (defined_.count(v.get())) { - Expr value = IRMutator::Mutate(op->value); + Expr value = this->VisitExpr(op->value); VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Expr body = IRMutator::Mutate(op->body); + Expr body = this->VisitExpr(op->body); scope_[v.get()].pop_back(); return Let::make(new_var, value, body); } else { defined_.insert(v.get()); - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (scope_.count(op->buffer_var.get())) { return Load::make( @@ -104,8 +108,8 @@ class IRConvertSSA final : public IRMutator { return expr; } } - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(op->buffer_var.get())) { return Store::make( @@ -115,41 +119,41 @@ class IRConvertSSA final : public IRMutator { return stmt; } } - Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { const VarExpr& v = op->var; if (defined_.count(v.get())) { - Expr value = IRMutator::Mutate(op->value); + Expr value = this->VisitExpr(op->value); VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Stmt body = IRMutator::Mutate(op->body); + Stmt body = this->VisitStmt(op->body); scope_[v.get()].pop_back(); return LetStmt::make(new_var, value, body); } else { defined_.insert(v.get()); - 
return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { const VarExpr& v = op->loop_var; if (defined_.count(v.get())) { VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); return For::make( new_var, op->min, op->extent, op->for_type, op->device_api, op->body); } else { defined_.insert(v.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { const VarExpr& v = op->buffer_var; if (defined_.count(v.get())) { VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); return Allocate::make( @@ -157,23 +161,23 @@ class IRConvertSSA final : public IRMutator { op->body, op->new_expr, op->free_function); } else { defined_.insert(v.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (const Variable* v = op->node.as()) { if (op->attr_key == attr::storage_scope) { const Allocate* alloc = op->body.as(); if (alloc && op->node.same_as(alloc->buffer_var)) { - Stmt new_alloc = Mutate(op->body); - if (new_alloc.same_as(op->body)) return s; + Stmt new_alloc = this->VisitStmt(op->body); + if (new_alloc.same_as(op->body)) return GetRef(op); alloc = new_alloc.as(); CHECK(alloc); return AttrStmt::make( alloc->buffer_var, op->attr_key, op->value, new_alloc); } } - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = 
stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { return AttrStmt::make( @@ -182,7 +186,7 @@ class IRConvertSSA final : public IRMutator { return stmt; } } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } @@ -194,13 +198,13 @@ class IRConvertSSA final : public IRMutator { } // namespace bool VerifySSA(const Stmt& ir) { - IRVerifySSA v; - v.Visit(ir); - return v.is_ssa; + IRVerifySSA visitor; + visitor(ir); + return visitor.is_ssa; } Stmt ConvertSSA(Stmt stmt) { - return IRConvertSSA().Mutate(stmt); + return IRConvertSSA()(std::move(stmt)); } } // namespace ir diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index bf8d4e020521..0d594046e5df 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -21,7 +21,6 @@ * \file storage_access.cc */ #include -#include #include #include #include @@ -32,7 +31,7 @@ namespace tvm { namespace ir { -void StorageAccessVisitor::Visit_(const Load* op) { +void StorageAccessVisitor::VisitExpr_(const Load* op) { const Variable* buf = op->buffer_var.as(); StorageScope scope = GetScope(buf); if (Enabled(buf, scope)) { @@ -47,10 +46,10 @@ void StorageAccessVisitor::Visit_(const Load* op) { curr_stmt_.access.emplace_back(std::move(e)); } // traverse child - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } -void StorageAccessVisitor::Visit_(const Store* op) { +void StorageAccessVisitor::VisitStmt_(const Store* op) { allow_append_ = true; CHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; @@ -67,7 +66,7 @@ void StorageAccessVisitor::Visit_(const Store* op) { curr_stmt_.access.emplace_back(std::move(e)); } // traverse child - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // push to the scope scope_.back().push_back(curr_stmt_); // clear access entry. 
@@ -75,11 +74,11 @@ void StorageAccessVisitor::Visit_(const Store* op) { allow_append_ = false; } -void StorageAccessVisitor::Visit_(const Evaluate* op) { +void StorageAccessVisitor::VisitStmt_(const Evaluate* op) { allow_append_ = true; CHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // push to the scope if (curr_stmt_.access.size() != 0) { scope_.back().push_back(curr_stmt_); @@ -88,17 +87,17 @@ void StorageAccessVisitor::Visit_(const Evaluate* op) { allow_append_ = false; } -void StorageAccessVisitor::Visit_(const AttrStmt* op) { +void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as(); storage_scope_[buf] = StorageScope::make(op->value.as()->value); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::double_buffer_write) { CHECK(double_buffer_write_ == nullptr); double_buffer_write_ = op->node.as(); scope_.push_back(std::vector()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); StmtEntry s; s.stmt = op; s.access = Summarize(std::move(scope_.back()), nullptr); @@ -115,7 +114,7 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) { } else if (op->attr_key == attr::coproc_scope) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); env_threads_.CopyOnWrite()->data.pop_back(); } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -123,23 +122,23 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) { if (!in_device_env_) { in_device_env_ = true; scope_.push_back(std::vector()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // no need to take the result as the thread barrier automatically syncs. 
Summarize(std::move(scope_.back()), nullptr); in_device_env_ = false; scope_.pop_back(); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } env_threads_.CopyOnWrite()->data.pop_back(); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } -void StorageAccessVisitor::Visit_(const For* op) { +void StorageAccessVisitor::VisitStmt_(const For* op) { scope_.push_back(std::vector()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); StmtEntry s; s.stmt = op; s.access = Summarize(std::move(scope_.back()), op); @@ -161,11 +160,11 @@ void StorageAccessVisitor::Visit_(const For* op) { } } -void StorageAccessVisitor::Visit_(const IfThenElse* op) { +void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) { ++condition_counter_; - this->Visit(op->condition); + this->VisitExpr(op->condition); scope_.push_back(std::vector()); - this->Visit(op->then_case); + this->VisitStmt(op->then_case); StmtEntry s; s.stmt = op; s.access = Summarize(std::move(scope_.back()), nullptr); @@ -180,10 +179,10 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) { --condition_counter_; } -void StorageAccessVisitor::Visit_(const Call* op) { +void StorageAccessVisitor::VisitExpr_(const Call* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load *l = op->args[0].as(); - IRVisitor::Visit_(l); + StmtExprVisitor::VisitExpr_(l); } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -211,7 +210,7 @@ void StorageAccessVisitor::Visit_(const Call* op) { curr_stmt_.access.emplace_back(e); } } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; @@ -224,7 +223,7 @@ void StorageAccessVisitor::Visit_(const Call* op) { curr_stmt_.access.emplace_back(std::move(e)); } } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } } 
@@ -236,11 +235,12 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const { return it->second; } -class StorageAccessInfoLower : public IRMutator { + +class StorageAccessInfoLower : public StmtExprMutator { public: - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { // Lower allocate to device allocate when needed. - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); // For special memory, remove allocate, or use head expr auto it = storage_info_.find(op->buffer_var.get()); @@ -259,7 +259,7 @@ class StorageAccessInfoLower : public IRMutator { return stmt; } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as(); StorageScope scope = StorageScope::make(op->value.as()->value); @@ -270,26 +270,26 @@ class StorageAccessInfoLower : public IRMutator { CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); } storage_info_[buf] = e; - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Call* op, const Expr &e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { - return MakeAccessPtr(op, e); + return MakeAccessPtr(op); } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } private: // tvm_access_ptr - Expr MakeAccessPtr(const Call* op, const Expr& e) { + Expr MakeAccessPtr(const Call* op) { // Specially handle the buffer packed intrinsic - Expr expr = IRMutator::Mutate_(op, e); + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -337,7 +337,7 @@ class StorageAccessInfoLower : public IRMutator { }; Stmt 
LowerStorageAccessInfo(Stmt stmt) { - return StorageAccessInfoLower().Mutate(stmt); + return StorageAccessInfoLower()(std::move(stmt)); } LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h index 302ca929581d..12bf9f369de3 100644 --- a/src/pass/storage_access.h +++ b/src/pass/storage_access.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include "../runtime/thread_storage_scope.h" @@ -40,7 +40,7 @@ using runtime::StorageRank; /*! * \brief Base class of storage access analysis */ -class StorageAccessVisitor : public IRVisitor { +class StorageAccessVisitor : public StmtExprVisitor { public: /*! \brief Storage access type */ enum AccessType { @@ -76,13 +76,13 @@ class StorageAccessVisitor : public IRVisitor { std::vector access; }; // override visitor pattern - void Visit_(const Load* op) final; - void Visit_(const Store* op) final; - void Visit_(const Evaluate* op) final; - void Visit_(const AttrStmt* op) final; - void Visit_(const For* op) final; - void Visit_(const IfThenElse* op) final; - void Visit_(const Call* op) final; + void VisitExpr_(const Load* op) final; + void VisitStmt_(const Store* op) final; + void VisitStmt_(const Evaluate* op) final; + void VisitStmt_(const AttrStmt* op) final; + void VisitStmt_(const For* op) final; + void VisitStmt_(const IfThenElse* op) final; + void VisitExpr_(const Call* op) final; protected: StorageAccessVisitor() { diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 2df2672adcb1..6bb3fc5a6025 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -26,8 +26,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -48,7 +47,7 @@ using runtime::StorageScope; using runtime::ThreadScope; using intrinsic::tvm_address_of; -class StorageFlattener : public IRMutator { +class StorageFlattener : public StmtExprMutator { public: explicit 
StorageFlattener(Map extern_buffer, int cache_line_size, bool create_bound_attributes, @@ -64,8 +63,8 @@ class StorageFlattener : public IRMutator { cache_line_size_ = cache_line_size; } - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = var_remap_.find(op->buffer_var.get()); if (it != var_remap_.end() && @@ -78,14 +77,14 @@ class StorageFlattener : public IRMutator { } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { Operation func = Downcast(op->node); - Stmt body = Mutate(op->body); + Stmt body = this->VisitStmt(op->body); for (int i = 0; i < func->num_outputs(); ++i) { TensorKey key{func, i}; auto it = buf_map_.find(key); @@ -99,7 +98,7 @@ class StorageFlattener : public IRMutator { IterVar iv = Downcast(op->node); ThreadScope ts = ThreadScope::make(iv->thread_tag); curr_thread_scope_.push_back(ts); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); curr_thread_scope_.pop_back(); return stmt; } else if (op->attr_key == attr::buffer_bind_scope) { @@ -116,17 +115,17 @@ class StorageFlattener : public IRMutator { } vinfo[dim].align_factor = tuple->args[1].as()->value; vinfo[dim].align_offset = tuple->args[2].as()->value; - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::opengl_stage_scope) { is_opengl_ = true; } - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Provide* op, const Stmt& s) final { + Stmt VisitStmt_(const Provide* op) final { if 
(create_bound_attributes_) shape_collector_.clear(); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); @@ -159,11 +158,11 @@ class StorageFlattener : public IRMutator { } } - Stmt Mutate_(const Realize* op, const Stmt& s) final { + Stmt VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; if (buf_map_.count(key)) { CHECK(buf_map_.at(key).external); - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else { // create a buffer entry BufferEntry e; @@ -226,7 +225,7 @@ class StorageFlattener : public IRMutator { align, 0, kDefault); buf_map_[key] = e; - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); buf_map_[key].released = true; Stmt ret; @@ -263,8 +262,8 @@ class StorageFlattener : public IRMutator { } } - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = var_remap_.find(op->buffer_var.get()); if (it != var_remap_.end() && @@ -277,17 +276,17 @@ class StorageFlattener : public IRMutator { } } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = var_remap_.find(op); if (it != var_remap_.end()) { return it->second; } else { - return e; + return GetRef(op); } } - Expr Mutate_(const Call* op, const Expr& olde) final { - Expr expr = IRMutator::Mutate_(op, olde); + Expr VisitExpr_(const Call* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op != nullptr && op->call_type == Call::Halide) { TensorKey key{op->func, op->value_index}; @@ -308,8 +307,8 @@ class StorageFlattener : public IRMutator { } } - Stmt Mutate_(const Prefetch *op, const Stmt &s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Prefetch *op) 
final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); CHECK(op != nullptr); TensorKey key{op->func, op->value_index}; @@ -443,7 +442,7 @@ class StorageFlattener : public IRMutator { // Apply the remaps Stmt body = MergeNest(binder.asserts(), op->body); body = MergeNest(binder.init_nest(), body); - body = this->Mutate(body); + body = this->VisitStmt(body); // remove the binds for (const Var& v : binder.defs()) { var_remap_.erase(v.get()); @@ -531,10 +530,10 @@ class StorageFlattener : public IRMutator { Stmt StorageFlatten(Stmt stmt, Map extern_buffer, int cache_line_size, bool create_bound_attributes) { IRVisitorWithAnalyzer bounded_analyzer; - bounded_analyzer.Visit(stmt); + bounded_analyzer(stmt); stmt = StorageFlattener(extern_buffer, cache_line_size, - create_bound_attributes, &bounded_analyzer).Mutate(stmt); + create_bound_attributes, &bounded_analyzer)(std::move(stmt)); return stmt; } diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 01c6f983d692..c820c477e128 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -24,8 +24,7 @@ */ #include #include -#include -#include +#include #include #include #include @@ -54,7 +53,7 @@ using runtime::StorageScope; // The storage need to be kept alive between allocate and last access. // The free point is only inserted at the same scope of allocate. // -class LinearAccessPatternFinder final : public IRVisitor { +class LinearAccessPatternFinder final : public StmtExprVisitor { public: /*! \brief record the touch hist of statment. 
*/ struct StmtEntry { @@ -78,7 +77,7 @@ class LinearAccessPatternFinder final : public IRVisitor { const Allocate* alloc{nullptr}; }; - void Visit_(const Allocate* op) final { + void VisitStmt_(const Allocate* op) final { size_t level = scope_.size(); const Variable* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); @@ -86,12 +85,12 @@ class LinearAccessPatternFinder final : public IRVisitor { CHECK(it->second.alloc == nullptr); it->second.alloc = op; it->second.level = level; - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Store* op) final { + void VisitStmt_(const Store* op) final { scope_.push_back(StmtEntry()); // visit subexpr - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // Add write access. const Variable* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); @@ -106,10 +105,10 @@ class LinearAccessPatternFinder final : public IRVisitor { linear_seq_.push_back(e); } } - void Visit_(const Evaluate* op) final { + void VisitStmt_(const Evaluate* op) final { scope_.push_back(StmtEntry()); // visit subexpr - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); StmtEntry e = scope_.back(); scope_.pop_back(); if (e.touched.size() != 0) { @@ -117,9 +116,9 @@ class LinearAccessPatternFinder final : public IRVisitor { linear_seq_.push_back(e); } } - void Visit_(const Load* op) final { + void VisitExpr_(const Load* op) final { // Add write access. 
- IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); const Variable* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { @@ -128,15 +127,15 @@ class LinearAccessPatternFinder final : public IRVisitor { scope_[it->second.level].touched.push_back(buf); } } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load* l = op->args[0].as(); - this->Visit(l->index); + this->VisitExpr(l->index); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } } - void Visit_(const Variable* buf) final { + void VisitExpr_(const Variable* buf) final { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { @@ -153,7 +152,7 @@ class LinearAccessPatternFinder final : public IRVisitor { int64_t begin_index = static_cast(linear_seq_.size()); // before scope. linear_seq_.push_back(e); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // after scope. e.touched = std::move(scope_.back().touched); scope_.pop_back(); @@ -165,7 +164,7 @@ class LinearAccessPatternFinder final : public IRVisitor { CHECK_NE(end_index, 0U); linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { // Only record the outer most thread extent. 
if (op->attr_key == attr::thread_extent && !in_thread_env_) { in_thread_env_ = true; @@ -179,20 +178,20 @@ class LinearAccessPatternFinder final : public IRVisitor { const Variable* buf = op->node.as(); alloc_info_[buf].storage_scope = StorageScope::make(op->value.as()->value); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const IfThenElse* op) final { + void VisitStmt_(const IfThenElse* op) final { VisitNewScope(op); } - void Visit_(const For* op) final { + void VisitStmt_(const For* op) final { VisitNewScope(op); } - void Visit_(const AssertStmt* op) final { + void VisitStmt_(const AssertStmt* op) final { VisitNewScope(op); } @@ -234,7 +233,7 @@ class LinearAccessPatternFinder final : public IRVisitor { // // The code after inplace transformation is no longer idempotent. // -class InplaceOpVerifier : public IRVisitor { +class InplaceOpVerifier : public StmtExprVisitor { public: bool Check(const Object* stmt, const Variable* dst, @@ -243,58 +242,62 @@ class InplaceOpVerifier : public IRVisitor { src_ = src; result_ = true; if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else { return false; } return result_; } - using IRVisitor::Visit_; + using StmtExprVisitor::VisitStmt_; - void Visit(const ObjectRef& e) final { + void VisitStmt(const Stmt& n) final { if (!result_) return; - IRVisitor::Visit(e); + StmtExprVisitor::VisitStmt(n); + } + void VisitExpr(const Expr& n) final { + if (!result_) return; + StmtExprVisitor::VisitExpr(n); } - void Visit_(const Variable* op) final { + void VisitExpr_(const Variable* op) final { // assume all opaque access 
is unsafe if (op == dst_ || op == src_) { result_ = false; return; } } - void Visit_(const Store* op) final { + void VisitStmt_(const Store* op) final { ++mem_nest_; - this->Visit(op->index); + this->VisitExpr(op->index); --mem_nest_; if (op->buffer_var.get() == dst_) { store_ = op; - this->Visit(op->value); - this->Visit(op->predicate); + this->VisitExpr(op->value); + this->VisitExpr(op->predicate); store_ = nullptr; } else { - this->Visit(op->value); - this->Visit(op->predicate); + this->VisitExpr(op->value); + this->VisitExpr(op->predicate); } } - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { // always reject extern code if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) { result_ = false; return; } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Load* op) final { + void VisitExpr_(const Load* op) final { const Variable* buf = op->buffer_var.get(); // cannot read from dst_ (no reduction) if (buf == dst_) { @@ -312,7 +315,7 @@ class InplaceOpVerifier : public IRVisitor { } } ++mem_nest_; - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); --mem_nest_; } @@ -332,7 +335,7 @@ class InplaceOpVerifier : public IRVisitor { }; // Planner to plan and rewrite memory allocation. 
-class StoragePlanRewriter : public IRMutator { +class StoragePlanRewriter : public StmtExprMutator { public: using StmtEntry = LinearAccessPatternFinder::StmtEntry; using AllocEntry = LinearAccessPatternFinder::AllocEntry; @@ -341,12 +344,12 @@ class StoragePlanRewriter : public IRMutator { detect_inplace_ = detect_inplace; // plan the rewrite LinearAccessPatternFinder finder; - finder.Visit(stmt); + finder(stmt); this->LivenessAnalysis(finder.linear_seq_); this->PlanMemory(finder.linear_seq_, finder.alloc_info_); this->PrepareNewAlloc(); // start rewrite - stmt = this->Mutate(stmt); + stmt = operator()(std::move(stmt)); if (attach_map_.count(nullptr)) { std::vector nest; for (StorageEntry* e : attach_map_.at(nullptr)) { @@ -363,8 +366,8 @@ class StoragePlanRewriter : public IRMutator { } return stmt; } - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return stmt; @@ -373,8 +376,8 @@ class StoragePlanRewriter : public IRMutator { RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); } - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return expr; @@ -383,7 +386,7 @@ class StoragePlanRewriter : public IRMutator { RemapIndex(op->dtype, op->index, it->second), op->predicate); } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = alloc_map_.find(op); if (it != alloc_map_.end()) { if (it->second->bits_offset != 0) { @@ -391,79 +394,81 @@ class StoragePlanRewriter : public IRMutator { } return it->second->alloc_var; } else { - return e; + return 
GetRef(op); } } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const Variable* buffer = op->args[1].as(); auto it = alloc_map_.find(buffer); - if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e); - const StorageEntry* se = it->second; - Expr offset = Mutate(op->args[2]); - Expr extent = Mutate(op->args[3]); - uint64_t elem_bits = dtype.bits() * dtype.lanes(); - CHECK_EQ(se->bits_offset % elem_bits, 0U); - if (se->bits_offset != 0) { - offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; - } - return Call::make( - op->dtype, op->name, - {op->args[0], se->alloc_var, offset, extent, op->args[4]}, - op->call_type); + if (it == alloc_map_.end()) { + return StmtExprMutator::VisitExpr_(op); + } + const StorageEntry* se = it->second; + Expr offset = this->VisitExpr(op->args[2]); + Expr extent = this->VisitExpr(op->args[3]); + uint64_t elem_bits = dtype.bits() * dtype.lanes(); + CHECK_EQ(se->bits_offset % elem_bits, 0U); + if (se->bits_offset != 0) { + offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; + } + return Call::make( + op->dtype, op->name, + {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + op->call_type); } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. 
if (attach_map_.count(op)) { auto& svec = attach_map_[op]; - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); return AttrStmt::make( op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } else if (op->attr_key == attr::volatile_scope) { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = alloc_map_.find(op->node.as()); if (it == alloc_map_.end()) return stmt; return AttrStmt::make( it->second->alloc_var, op->attr_key, op->value, op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { CHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); return For::make( op->loop_var, op->min, op->extent, op->for_type, op->device_api, MakeAttach(svec, op->body)); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const Allocate* op, const Stmt& s) final { - return this->Mutate(op->body); + Stmt VisitStmt_(const Allocate* op) final { + return this->VisitStmt(op->body); } private: @@ -929,28 +934,28 @@ class StoragePlanRewriter : public IRMutator { // Turn alloc into vector alloc // if all its access is the same vector type. 
-class VectorAllocRewriter : public IRMutator { +class VectorAllocRewriter : public StmtExprMutator { public: - Expr Mutate_(const Load* op, const Expr& e) final { + Expr VisitExpr_(const Load* op) final { UpdateTypeMap(op->buffer_var.get(), op->dtype); - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt VisitStmt_(const Store* op) final { UpdateTypeMap(op->buffer_var.get(), op->value.dtype()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { DataType dtype = op->args[0].dtype(); const Variable* buffer = op->args[1].as(); UpdateTypeMap(buffer, dtype); } - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Allocate* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Allocate* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); const auto& tvec = acc_map_[op->buffer_var.get()]; @@ -989,7 +994,7 @@ class VectorAllocRewriter : public IRMutator { LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { auto n = make_object(*f.operator->()); VectorAllocRewriter rewriter; - n->body = rewriter.Mutate(n->body); + n->body = rewriter(n->body); for (Var arg : f->args) { if (arg.dtype().is_handle()) { const auto& tvec = rewriter.acc_map_[arg.get()]; @@ -1010,8 +1015,8 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { } Stmt StorageRewrite(Stmt stmt) { - stmt = StoragePlanRewriter().Rewrite(stmt, true); - return VectorAllocRewriter().Mutate(stmt); + stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); + return VectorAllocRewriter()(std::move(stmt)); } } // namespace ir } // namespace tvm diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 0f8bef8383f2..6ace4f7f85b4 100644 --- 
a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -22,8 +22,7 @@ */ #include #include -#include -#include +#include #include #include #include "ir_util.h" @@ -197,13 +196,13 @@ class ThreadSyncPlanner : public StorageAccessVisitor { StorageScope sync_scope_; }; -class ThreadSyncInserter : public IRMutator { +class ThreadSyncInserter : public StmtExprMutator { public: ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set& syncs) : sync_scope_(sync_scope), syncs_(syncs) {} - Stmt Mutate(Stmt stmt) final { + Stmt VisitStmt(const Stmt& stmt) final { if (syncs_.size() == 0) return stmt; if (syncs_.count(stmt.get())) { Stmt barrier; @@ -216,33 +215,33 @@ class ThreadSyncInserter : public IRMutator { Call::Intrinsic)); } // Mutate after query, to avoid stmt change. - stmt = IRMutator::Mutate(stmt); - stmt = Block::make(barrier, stmt); + auto ret = StmtExprMutator::VisitStmt(stmt); + ret = Block::make(barrier, ret); + return ret; } else { - stmt = IRMutator::Mutate(stmt); + return StmtExprMutator::VisitStmt(stmt); } - return stmt; } - Expr Mutate_(const Load* op, const Expr& e) final { + Expr VisitExpr_(const Load* op) final { if (sync_scope_.rank == StorageRank::kGlobal && GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; } - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt VisitStmt_(const Store* op) final { if (sync_scope_.rank == StorageRank::kGlobal && GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].write_count; } - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { bool temp = true; std::swap(temp, in_thread_env_); thread_extents_.push_back(op); - Stmt ret = IRMutator::Mutate_(op, s); + 
Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); std::swap(temp, in_thread_env_); // first thread scope. @@ -256,15 +255,15 @@ class ThreadSyncInserter : public IRMutator { const Variable* buf = op->node.as(); storage_scope_[buf] = StorageScope::make(op->value.as()->value); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { - Expr expr = IRMutator::Mutate_(op, e); + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); const Variable* buffer_var = op->args[1].as(); @@ -280,7 +279,7 @@ class ThreadSyncInserter : public IRMutator { } return expr; } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } @@ -363,8 +362,8 @@ class ThreadSyncInserter : public IRMutator { Stmt ThreadSync(Stmt stmt, std::string storage_scope) { StorageScope sync_scope = StorageScope::make(storage_scope); ThreadSyncPlanner planner(sync_scope); - planner.Visit(stmt); - return ThreadSyncInserter(sync_scope, planner.syncs_inserted_).Mutate(stmt); + planner(stmt); + return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); } LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index 8dcc0e49e119..a3890cde773e 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -24,8 +24,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -73,7 +72,7 @@ Expr unpack_type_cast(const Expr &input, const DataType &target_type) { // MMAMatcher matches C = Cast(A)*Cast(B)+C, // where A & B are fp16/int8 local buffers, // and C is fp32/int32 local buffer. 
-class MMAMatcher: public IRVisitor { +class MMAMatcher: public StmtVisitor { public: explicit MMAMatcher(Map extern_buffer) { for (auto kv : extern_buffer) { @@ -84,22 +83,21 @@ class MMAMatcher: public IRVisitor { buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; } } - using IRVisitor::Visit_; - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::pragma_tensor_core) { tensor_core_on_ = true; - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } else if (op->attr_key == attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; - Visit(op->body); + this->VisitStmt(op->body); } else { - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } } - void Visit_(const Provide* op) final { - IRVisitor::Visit_(op); + void VisitStmt_(const Provide* op) final { + StmtVisitor::VisitStmt_(op); auto it = buf_map_.find(TensorKey{op->func, op->value_index}); if (it == buf_map_.end()) { return; @@ -113,19 +111,19 @@ class MMAMatcher: public IRVisitor { } } - void Visit_(const Realize* op) final { + void VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; if (buf_map_.count(key)) { if (!buf_map_.at(key).external) { return; } - Visit(op->body); + this->VisitStmt(op->body); } else { BufferInfo bi; bi.name = key.GetName(); bi.dtype = op->dtype; buf_map_[key] = bi; - Visit(op->body); + this->VisitStmt(op->body); buf_map_[key].released = true; } } @@ -236,12 +234,11 @@ class MMAMatcher: public IRVisitor { // BodyVisitor visits the body stmt of original ComputeOp // to get the access indices of input matrices, // if it is recognized as matrix multiply. 
-class BodyVisitor : public IRVisitor { +class BodyVisitor : public StmtExprVisitor { public: BodyVisitor() {} - using IRVisitor::Visit_; - void Visit_(const Reduce* op) final { + void VisitExpr_(const Reduce* op) final { auto* comm_add = op->combiner->result[0].as(); if (comm_add == nullptr || op->combiner->result.size() > 1) { return; @@ -254,12 +251,12 @@ class BodyVisitor : public IRVisitor { } tensorcore_candidate_ = true; - IRVisitor::Visit(source); + StmtExprVisitor::VisitExpr(source); } } - void Visit_(const Call* op) final { - IRVisitor::Visit_(op); + void VisitExpr_(const Call* op) final { + StmtExprVisitor::VisitExpr_(op); args_.insert(std::make_pair(op->name, op->args)); } @@ -298,7 +295,7 @@ class ScheduleAnalyser { BodyVisitor body_visitor; for (Expr expr : compute->body) { - body_visitor.Visit(expr); + body_visitor(expr); } if (!body_visitor.tensorcore_candidate_) { continue; @@ -370,12 +367,11 @@ class ScheduleAnalyser { // IndexVisitor visits access index of fragment // to record variable for loop scaling -class IndexVisitor : public IRVisitor { +class IndexVisitor : public StmtExprVisitor { public: IndexVisitor() {} - using IRVisitor::Visit_; - void Visit_(const Variable* op) final { + void VisitExpr_(const Variable* op) final { loop_scaling_.insert(std::make_pair(op, scaling_factor_)); } @@ -389,7 +385,7 @@ class IndexVisitor : public IRVisitor { // BufferAnalyser gets buffer info, // e.g. 
thread tile and warp tile, for TensorCore CodeGen -class BufferAnalyser : public IRVisitor { +class BufferAnalyser : public StmtExprVisitor { public: explicit BufferAnalyser(Map extern_buffer, const ScheduleAnalyser &schedule_analyser, @@ -407,9 +403,8 @@ class BufferAnalyser : public IRVisitor { buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; } } - using IRVisitor::Visit_; - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { if (const IntImm* value = op->value.as()) { thread_extent_.insert( @@ -417,10 +412,10 @@ class BufferAnalyser : public IRVisitor { op->node.as()->var->name_hint, value->value)); } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; - Visit(op->body); + this->VisitStmt(op->body); } else if (op->attr_key == attr::buffer_dim_align) { Tensor tensor = Downcast(op->node); const Call* tuple = op->value.as(); @@ -432,14 +427,14 @@ class BufferAnalyser : public IRVisitor { } vinfo[dim].align_factor = tuple->args[1].as()->value; vinfo[dim].align_offset = tuple->args[2].as()->value; - Visit(op->body); + this->VisitStmt(op->body); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const Provide* op) final { - IRVisitor::Visit_(op); + void VisitStmt_(const Provide* op) final { + StmtExprVisitor::VisitStmt_(op); TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); CHECK(it != buf_map_.end()) @@ -503,7 +498,7 @@ class BufferAnalyser : public IRVisitor { } auto index = rel_index[i]; auto simplified_index = ir::Simplify(index); - index_visitor.Visit(simplified_index); + index_visitor(simplified_index); } std::string input_name = simplify_name(bi.name); @@ -550,8 +545,8 @@ class BufferAnalyser : public IRVisitor { } } - void Visit_(const Call* op) final { - IRVisitor::Visit_(op); + void VisitExpr_(const 
Call* op) final { + StmtExprVisitor::VisitExpr_(op); if (op->call_type == Call::Halide) { TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); @@ -606,16 +601,16 @@ class BufferAnalyser : public IRVisitor { } auto index = rel_index[i]; auto simplified_index = ir::Simplify(index); - index_visitor.Visit(simplified_index); + index_visitor(simplified_index); } } } - void Visit_(const Realize* op) final { + void VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; if (buf_map_.count(key)) { CHECK(buf_map_.at(key).external); - Visit(op->body); + this->VisitStmt(op->body); } else { // create a buffer entry BufferInfo bi; @@ -653,7 +648,7 @@ class BufferAnalyser : public IRVisitor { bi.shape = shape; buf_map_[key] = bi; - Visit(op->body); + this->VisitStmt(op->body); buf_map_[key].released = true; } } @@ -761,12 +756,12 @@ class BufferAnalyser : public IRVisitor { }; // ThreadIdxMutator does the thread index unification inside a warp -class ThreadIdxMutator : public IRMutator { +class ThreadIdxMutator : public StmtExprMutator { public: explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {} - Expr Mutate_(const Variable* op, const Expr& olde) final { - Expr expr = IRMutator::Mutate_(op, olde); + Expr VisitExpr_(const Variable* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op != nullptr) { if (op->name_hint == "threadIdx.x") { @@ -788,7 +783,7 @@ class ThreadIdxMutator : public IRMutator { // TensorCoreIRMutator mutates the AST for TensorCore CodeGen // based on tensor core intrinsics -class TensorCoreIRMutator : public IRMutator { +class TensorCoreIRMutator : public StmtExprMutator { public: explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser, const BufferAnalyser &buffer_analyser) @@ -803,10 +798,10 @@ class TensorCoreIRMutator : public IRMutator { warp_tile_(buffer_analyser.warp_tile_), warp_threads_y_(buffer_analyser.warp_threads_y_) {} - Stmt Mutate_(const Realize* op, 
const Stmt& s) final { + Stmt VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; bounds_[key] = op->bounds; - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op != nullptr) { if (!frag_reg_.count(key.GetName())) { @@ -833,8 +828,8 @@ class TensorCoreIRMutator : public IRMutator { return stmt; } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const AttrStmt* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); if (op->attr_key == attr::realize_scope) { auto node = op->node.as(); if (node != nullptr) { @@ -846,7 +841,7 @@ class TensorCoreIRMutator : public IRMutator { CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; auto matrix_abc = "wmma." + it->second; - Stmt body = Mutate(op->body); + Stmt body = this->VisitStmt(op->body); return AttrStmt::make(op->node, op->attr_key, matrix_abc, @@ -856,8 +851,8 @@ class TensorCoreIRMutator : public IRMutator { return stmt; } - Stmt Mutate_(const Provide* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Provide* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); auto it = mma_sync_.find(op); if (it != mma_sync_.end()) { const auto &operands = it->second; @@ -941,7 +936,7 @@ class TensorCoreIRMutator : public IRMutator { // thread index unification inside a warp Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); - Expr mutated_value = thread_idx_mutator.Mutate(op->value); + Expr mutated_value = thread_idx_mutator(op->value); Expr src = Call::make(value->dtype, "&", {mutated_value}, @@ -991,7 +986,7 @@ class TensorCoreIRMutator : public IRMutator { // thread index unification inside a warp Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); - dst = 
thread_idx_mutator.Mutate(dst); + dst = thread_idx_mutator(dst); dst = Call::make(DataType::Handle(), "&", {dst}, @@ -1020,8 +1015,8 @@ class TensorCoreIRMutator : public IRMutator { return stmt; } - Stmt Mutate_(const For* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const For* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op != nullptr) { auto it = loop_scaling_.find(op->loop_var.get()); @@ -1177,7 +1172,7 @@ Stmt RewriteForTensorCore(Stmt stmt, } MMAMatcher mma_matcher(extern_buffer); - mma_matcher.Visit(stmt); + mma_matcher(stmt); if (!mma_matcher.Matched()) { return stmt; } @@ -1189,12 +1184,12 @@ Stmt RewriteForTensorCore(Stmt stmt, BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); - buffer_analyser.Visit(stmt); + buffer_analyser(stmt); if (!buffer_analyser.QualifiedForTensorCore()) { return stmt; } - return TensorCoreIRMutator(schedule_analyser, buffer_analyser).Mutate(stmt); + return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); } } // namespace ir diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index c56944406bd4..9fc87f3a0d6b 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -24,7 +24,7 @@ // Unrolls the loop as in Halide pipeline. 
#include #include -#include +#include #include #include #include @@ -33,7 +33,7 @@ namespace tvm { namespace ir { -class LoopUnroller : public IRMutator { +class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, @@ -45,12 +45,12 @@ class LoopUnroller : public IRMutator { explicit_unroll_(explicit_unroll) { } - Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { int value = 0; CHECK(arith::GetConstInt(op->value, &value)); std::swap(value, auto_max_step_); - Stmt ret = this->Mutate(op->body); + Stmt ret = this->VisitStmt(op->body); std::swap(value, auto_max_step_); return ret; } else if (op->attr_key == "pragma_unroll_explicit") { @@ -58,16 +58,16 @@ class LoopUnroller : public IRMutator { CHECK(arith::GetConstInt(op->value, &value)); bool explicit_unroll = value; std::swap(explicit_unroll, explicit_unroll_); - Stmt ret = this->Mutate(op->body); + Stmt ret = this->VisitStmt(op->body); std::swap(explicit_unroll, explicit_unroll_); return ret; } else { - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const For* op) { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); int value = GetExtent(op); // condition for auto unroll @@ -110,18 +110,18 @@ class LoopUnroller : public IRMutator { } } - Stmt Mutate_(const Store* op, const Stmt& stmt) final { + Stmt VisitStmt_(const Store* op) final { ++step_count_; - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Evaluate* op, const Stmt& stmt) final { + Stmt VisitStmt_(const Evaluate* op) final { ++step_count_; - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Block* op, const Stmt& stmt) final { - Stmt first = 
this->Mutate(op->first); + Stmt VisitStmt_(const Block* op) final { + Stmt first = this->VisitStmt(op->first); // cleanup state int step_count = step_count_; int unroll_depth = unroll_depth_; @@ -130,13 +130,13 @@ class LoopUnroller : public IRMutator { unroll_depth_ = 0; normal_loop_depth_ = 0; // work on rest part - Stmt rest = this->Mutate(op->rest); + Stmt rest = this->VisitStmt(op->rest); step_count_ += step_count; normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); unroll_depth_ = std::max(unroll_depth_, unroll_depth); if (first.same_as(op->first) && rest.same_as(op->rest)) { - return stmt; + return GetRef(op); } else { return Block::make(first, rest); } @@ -204,7 +204,7 @@ Stmt UnrollLoop(Stmt stmt, auto_max_step, auto_max_depth, auto_max_extent, - explicit_unroll).Mutate(stmt); + explicit_unroll)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 94639f7be363..c22243cc8f93 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -23,7 +23,7 @@ // Loop vectorizer as in Halide pipeline. #include #include -#include +#include #include #include #include @@ -54,13 +54,13 @@ inline Expr BroadcastTo(Expr e, int lanes) { // // The same principle applies when using one thread to simulate multiple context. 
// -class VecAllocAccess : public IRMutator { +class VecAllocAccess : public StmtExprMutator { public: VecAllocAccess(const Variable* buf, Var var, int var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} // Load - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->buffer_var.get() == buf_) { return Load::make(op->dtype, op->buffer_var, @@ -71,8 +71,8 @@ class VecAllocAccess : public IRMutator { } } // Store - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op->buffer_var.get() == buf_) { return Store::make(op->buffer_var, @@ -93,19 +93,16 @@ class VecAllocAccess : public IRMutator { int var_lanes_; }; -class Vectorizer : public IRMutator { +class Vectorizer : public StmtExprMutator { public: Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { ramp_ = Ramp::make(0, 1, var_lanes); } - // user mutate from parent. 
- using IRMutator::Mutate; - Stmt Mutate(Stmt stmt) final { + Stmt VisitStmt(const Stmt& stmt) final { CHECK(!need_scalarize_); - - Stmt ret = IRMutator::Mutate(stmt); + Stmt ret = StmtExprMutator::VisitStmt(stmt); if (need_scalarize_) { need_scalarize_ = false; return Scalarize(stmt); @@ -114,19 +111,18 @@ class Vectorizer : public IRMutator { } } - - Expr Mutate_(const Add* op, const Expr &e) final { - return AddSubVec(op, e); + Expr VisitExpr_(const Add* op) final { + return AddSubVec(op); } - Expr Mutate_(const Sub* op, const Expr &e) final { - return AddSubVec(op, e); + Expr VisitExpr_(const Sub* op) final { + return AddSubVec(op); } - Expr Mutate_(const Mul* op, const Expr &e) final { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); + Expr VisitExpr_(const Mul* op) final { + Expr a = this->VisitExpr(op->a); + Expr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return e; + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); if (lanes != 1) { @@ -143,53 +139,53 @@ class Vectorizer : public IRMutator { } return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } - return BinaryVec(op, e); + return BinaryVec(op); } - Expr Mutate_(const Div* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Div* op) final { + return BinaryVec(op); } - Expr Mutate_(const Mod* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Mod* op) final { + return BinaryVec(op); } - Expr Mutate_(const FloorDiv* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const FloorDiv* op) final { + return BinaryVec(op); } - Expr Mutate_(const FloorMod* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const FloorMod* op) final { + return BinaryVec(op); } - Expr Mutate_(const Min* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Min* op) final { + return BinaryVec(op); } - Expr Mutate_(const 
Max* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Max* op) final { + return BinaryVec(op); } - Expr Mutate_(const EQ* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const EQ* op) final { + return BinaryVec(op); } - Expr Mutate_(const NE* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const NE* op) final { + return BinaryVec(op); } - Expr Mutate_(const LT* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const LT* op) final { + return BinaryVec(op); } - Expr Mutate_(const LE* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const LE* op) final { + return BinaryVec(op); } - Expr Mutate_(const GT* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const GT* op) final { + return BinaryVec(op); } - Expr Mutate_(const GE* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const GE* op) final { + return BinaryVec(op); } - Expr Mutate_(const And* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const And* op) final { + return BinaryVec(op); } - Expr Mutate_(const Or* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Or* op) final { + return BinaryVec(op); } - Expr Mutate_(const Ramp* op, const Expr &e) final { - Expr base = this->Mutate(op->base); - Expr stride = this->Mutate(op->stride); + Expr VisitExpr_(const Ramp* op) final { + Expr base = this->VisitExpr(op->base); + Expr stride = this->VisitExpr(op->stride); if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { const Ramp* base_ramp = base.as(); if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { @@ -208,14 +204,14 @@ class Vectorizer : public IRMutator { } return Shuffle::make_concat(elems); } - Expr Mutate_(const Select *op, const Expr& e) final { - Expr cond = this->Mutate(op->condition); - Expr t = this->Mutate(op->true_value); - Expr f = 
this->Mutate(op->false_value); + Expr VisitExpr_(const Select *op) final { + Expr cond = this->VisitExpr(op->condition); + Expr t = this->VisitExpr(op->true_value); + Expr f = this->VisitExpr(op->false_value); if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { - return e; + return GetRef(op); } else { int lanes = std::max(std::max( cond.dtype().lanes(), @@ -223,37 +219,37 @@ class Vectorizer : public IRMutator { return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } - Expr Mutate_(const Cast *op, const Expr& e) final { - Expr value = this->Mutate(op->value); + Expr VisitExpr_(const Cast *op) final { + Expr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return e; + return GetRef(op); } else { return Cast::make(op->dtype.with_lanes(value.dtype().lanes()), value); } } // Variable - Expr Mutate_(const Variable* v, const Expr& e) final { + Expr VisitExpr_(const Variable* v) final { if (v == var_.get()) { return ramp_; } else if (lets_.count(v)) { return lets_[v]; } else { - return e; + return GetRef(v); } } // IfThenElse expr - Expr MutateIfThenElseExpr_(const Call *op, const Expr& e) { - Expr cond = this->Mutate(op->args[0]); + Expr MutateIfThenElseExpr_(const Call *op) { + Expr cond = this->VisitExpr(op->args[0]); if (cond.dtype().is_vector()) { need_scalarize_ = true; - return e; + return GetRef(op); } - Expr t = this->Mutate(op->args[1]); - Expr f = this->Mutate(op->args[2]); + Expr t = this->VisitExpr(op->args[1]); + Expr f = this->VisitExpr(op->args[2]); if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { - return e; + return GetRef(op); } else { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); @@ -264,23 +260,23 @@ class Vectorizer : public IRMutator { } } // Call - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->name == intrinsic::tvm_if_then_else) { - 
return MutateIfThenElseExpr_(op, e); + return MutateIfThenElseExpr_(op); } if (!op->is_vectorizable()) { // Cannot vectorize this op Array new_args; for (auto arg : op->args) { - auto new_arg = this->Mutate(arg); + auto new_arg = this->VisitExpr(arg); if (new_arg.dtype().is_vector()) { need_scalarize_ = true; - return e; + return GetRef(op); } new_args.push_back(new_arg); } if (op->args.same_as(new_args)) { - return e; + return GetRef(op); } else { return Call::make( op->dtype, op->name, new_args, op->call_type, op->func, op->value_index); @@ -290,7 +286,7 @@ class Vectorizer : public IRMutator { Array new_args = MutateArray(op->args, &lane); // normal code path. if (op->args.same_as(new_args)) { - return e; + return GetRef(op); } else { return Call::make( op->dtype.with_lanes(lane), op->name, new_args, @@ -299,11 +295,11 @@ class Vectorizer : public IRMutator { } } // Load - Expr Mutate_(const Load* op, const Expr& e) final { - Expr index = this->Mutate(op->index); - Expr pred = this->Mutate(op->predicate); + Expr VisitExpr_(const Load* op) final { + Expr index = this->VisitExpr(op->index); + Expr pred = this->VisitExpr(op->predicate); if (index.same_as(op->index) && pred.same_as(op->predicate)) { - return e; + return GetRef(op); } else { int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); return Load::make( @@ -314,42 +310,42 @@ class Vectorizer : public IRMutator { } } // Let - Expr Mutate_(const Let* op, const Expr& e) final { - Expr value = this->Mutate(op->value); + Expr VisitExpr_(const Let* op) final { + Expr value = this->VisitExpr(op->value); CHECK(!lets_.count(op->var.get())) << "not SSA"; if (value.dtype().lanes() != op->value.dtype().lanes()) { Var v(op->var->name_hint, value.dtype()); lets_[op->var.get()] = v; - return Let::make(v, value, Mutate(op->body)); + return Let::make(v, value, this->VisitExpr(op->body)); } else { - Expr body = this->Mutate(op->body); + Expr body = this->VisitExpr(op->body); if (value.same_as(op->value) && 
body.same_as(op->body)) { - return e; + return GetRef(op); } else { return Let::make(op->var, value, body); } } } // Provide - Stmt Mutate_(const Provide* op, const Stmt& s) final { - Expr new_value = this->Mutate(op->value); + Stmt VisitStmt_(const Provide* op) final { + Expr new_value = this->VisitExpr(op->value); int lane = new_value.dtype().lanes(); Array new_args = MutateArray(op->args, &lane); if (op->args.same_as(new_args) && op->value.same_as(new_value)) { - return s; + return GetRef(op); } else { new_value = BroadcastTo(new_value, lane); return Provide::make(op->func, op->value_index, new_value, new_args); } } // Store - Stmt Mutate_(const Store* op, const Stmt& s) final { - Expr value = this->Mutate(op->value); - Expr index = this->Mutate(op->index); - Expr pred = this->Mutate(op->predicate); + Stmt VisitStmt_(const Store* op) final { + Expr value = this->VisitExpr(op->value); + Expr index = this->VisitExpr(op->index); + Expr pred = this->VisitExpr(op->predicate); if (value.same_as(op->value) && index.same_as(op->index)) { - return s; + return GetRef(op); } else { int lanes = std::max(value.dtype().lanes(), index.dtype().lanes()); lanes = std::max(lanes, pred.dtype().lanes()); @@ -360,20 +356,20 @@ class Vectorizer : public IRMutator { } } // For - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { if (op->for_type == ForType::Vectorized) { LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..."; } CHECK(is_zero(op->min)); CHECK(!op->extent.dtype().is_vector()); - Expr extent = Mutate(op->extent); + Expr extent = this->VisitExpr(op->extent); if (extent.dtype().is_vector()) { - return Scalarize(s); + return Scalarize(GetRef(op)); } - Stmt body = Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (extent.same_as(op->extent) && body.same_as(op->body)) { - return s; + return GetRef(op); } else { return For::make( op->loop_var, op->min, extent, @@ -381,47 +377,47 @@ class Vectorizer : public 
IRMutator { } } // IfThenElse - Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { + Stmt VisitStmt_(const IfThenElse* op) final { CHECK(!op->condition.dtype().is_vector()); - Expr condition = this->Mutate(op->condition); + Expr condition = this->VisitExpr(op->condition); if (condition.dtype().is_vector()) { - return Scalarize(s); + return Scalarize(GetRef(op)); } - Stmt then_case = this->Mutate(op->then_case); + Stmt then_case = this->VisitStmt(op->then_case); Stmt else_case; if (op->else_case.defined()) { - else_case = this->Mutate(op->else_case); + else_case = this->VisitStmt(op->else_case); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return s; + return GetRef(op); } else { return IfThenElse::make(condition, then_case, else_case); } } // LetStmt - Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize"; - return Scalarize(s); + return Scalarize(GetRef(op)); } // Allocate - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { if (op->new_expr.defined()) { LOG(WARNING) << "Cannot vectorize with new expr"; - return Scalarize(s); + return Scalarize(GetRef(op)); } - Expr condition = Mutate(op->condition); + Expr condition = this->VisitExpr(op->condition); if (condition.dtype().is_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc "; - return Scalarize(s); + return Scalarize(GetRef(op)); } Array extents; for (size_t i = 0; i < op->extents.size(); i++) { - Expr new_ext = Mutate(op->extents[i]); + Expr new_ext = this->VisitExpr(op->extents[i]); if (new_ext.dtype().is_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc "; - return Scalarize(s); + return Scalarize(GetRef(op)); } extents.push_back(new_ext); } @@ -429,8 +425,8 @@ class Vectorizer : public IRMutator { 
extents.push_back(var_lanes_); // rewrite access to buffer internally. Stmt body = VecAllocAccess( - op->buffer_var.get(), var_, var_lanes_).Mutate(op->body); - body = Mutate(body); + op->buffer_var.get(), var_, var_lanes_)(op->body); + body = this->VisitStmt(body); return Allocate::make( op->buffer_var, op->dtype, extents, condition, body, @@ -466,7 +462,7 @@ class Vectorizer : public IRMutator { std::vector new_arr(arr.size()); for (size_t i = 0; i < arr.size(); i++) { Expr old_elem = arr[i]; - Expr new_elem = this->Mutate(old_elem); + Expr new_elem = this->VisitExpr(old_elem); if (!new_elem.same_as(old_elem)) changed = true; new_arr[i] = new_elem; lanes = std::max(lanes, new_elem.dtype().lanes()); @@ -482,24 +478,24 @@ class Vectorizer : public IRMutator { return Array(new_arr); } template - Expr BinaryVec(const T* op, const Expr& e) { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); + Expr BinaryVec(const T* op) { + Expr a = this->VisitExpr(op->a); + Expr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return e; + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } template - Expr AddSubVec(const T* op, const Expr& e) { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); + Expr AddSubVec(const T* op) { + Expr a = this->VisitExpr(op->a); + Expr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return e; + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); if (lanes != 1) { @@ -521,9 +517,9 @@ class Vectorizer : public IRMutator { } }; -class LoopVectorizer : public IRMutator { +class LoopVectorizer : public StmtMutator { public: - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { if (op->for_type == ForType::Vectorized) { CHECK(is_zero(op->min)); int lanes = 0; @@ -531,21 +527,21 @@ class LoopVectorizer : 
public IRMutator { if (!succ || lanes < 1) { LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; } - return Vectorizer(op->loop_var, lanes).Mutate(op->body); + return Vectorizer(op->loop_var, lanes)(op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } } }; Stmt VectorizeLoop(Stmt stmt) { - return LoopVectorizer().Mutate(stmt); + return LoopVectorizer()(std::move(stmt)); } -class VectorizeSkipper : public IRMutator { +class VectorizeSkipper : public StmtMutator { public: - Stmt Mutate_(const For* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const For* op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (op->for_type == ForType::Vectorized) { return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, @@ -557,7 +553,7 @@ class VectorizeSkipper : public IRMutator { }; Stmt SkipVectorize(Stmt stmt) { - return VectorizeSkipper().Mutate(stmt); + return VectorizeSkipper()(std::move(stmt)); } } // namespace ir diff --git a/src/pass/verify_compact_buffer.cc b/src/pass/verify_compact_buffer.cc index c2131f6c7687..671b4a073117 100644 --- a/src/pass/verify_compact_buffer.cc +++ b/src/pass/verify_compact_buffer.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -32,15 +32,15 @@ namespace tvm { namespace ir { -class VerifyBuffer : public IRVisitor { +class VerifyBuffer : public StmtVisitor { public: bool Verify(const Stmt& stmt) { - this->Visit(stmt); + this->VisitStmt(stmt); return is_compact_; } - void Visit_(const AttrStmt* op) final { - IRVisitor::Visit_(op); + void VisitStmt_(const AttrStmt* op) final { + StmtVisitor::VisitStmt_(op); if (op->attr_key == attr::buffer_bind_scope) { is_compact_ = true; } diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 49a05345a99f..1adc6851caf4 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -6,9 +6,9 
@@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,12 +26,12 @@ #include #include -#include +#include namespace tvm { namespace ir { -class GPUCodeVerifier : public IRVisitor { +class GPUCodeVerifier : public StmtVisitor { public: bool Verify(tvm::Stmt stmt, int64_t max_local_memory_per_block, @@ -49,12 +49,12 @@ class GPUCodeVerifier : public IRVisitor { Reset_(); - this->Visit(stmt); + this->VisitStmt(stmt); return valid_; } - void Visit_(const ProducerConsumer *op) { + void VisitStmt_(const ProducerConsumer* op) final { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); @@ -62,10 +62,10 @@ class GPUCodeVerifier : public IRVisitor { if (op->is_producer) { nest_level_++; - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); nest_level_--; } else { - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } if (nest_level_ == 0) { @@ -77,8 +77,8 @@ class GPUCodeVerifier : public IRVisitor { } } - void Visit_(const Allocate *op) { - IRVisitor::Visit_(op); + void VisitStmt_(const Allocate* op) final { + StmtVisitor::VisitStmt_(op); // visit an allocation of a buffer in shared memory, record its size if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { size_t size = static_cast(op->constant_allocation_size()); @@ -89,7 +89,7 @@ class GPUCodeVerifier : public IRVisitor { } } - void Visit_(const AttrStmt *op) { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { std::string op_value = op->value.as()->value; if (op_value == "local") { @@ -132,7 +132,7 @@ class GPUCodeVerifier : public IRVisitor { } } } - IRVisitor::Visit_(op); + 
StmtVisitor::VisitStmt_(op); } private: diff --git a/src/pass/verify_memory.cc b/src/pass/verify_memory.cc index 4a5c8adeb8e7..415841d04017 100644 --- a/src/pass/verify_memory.cc +++ b/src/pass/verify_memory.cc @@ -22,8 +22,9 @@ * \brief Pass to check if memory accesses are legal. */ #include -#include #include +#include + namespace tvm { namespace ir { @@ -39,7 +40,7 @@ namespace { * This pass performs such verification by checking if all Producer/Consumer * with memory accesses are bound with threads when device type is GPU. */ -class MemoryAccessVerifier final : protected IRVisitor { +class MemoryAccessVerifier final : protected StmtExprVisitor { public: /// Special member functions //@{ @@ -55,7 +56,7 @@ class MemoryAccessVerifier final : protected IRVisitor { /// Interface to perform memory access verification void Run() { if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return; - IRVisitor::Visit(func_->body); + StmtExprVisitor::VisitStmt(func_->body); } /// Verification result @@ -64,42 +65,47 @@ class MemoryAccessVerifier final : protected IRVisitor { protected: /// Visitor implementation //@{ - void Visit(const ObjectRef &n) final { + void VisitExpr(const Expr &n) final { + if (Failed()) return; + StmtExprVisitor::VisitExpr(n); + } + + void VisitStmt(const Stmt &n) final { if (Failed()) return; - IRVisitor::Visit(n); + StmtExprVisitor::VisitStmt(n); } - void Visit_(const LetStmt *op) final { + void VisitStmt_(const LetStmt* op) final { // Book keep definitions defs_[op->var.get()] = op->value; - return IRVisitor::Visit_(op); + return StmtExprVisitor::VisitStmt_(op); } - void Visit_(const AttrStmt *op) final { + void VisitStmt_(const AttrStmt* op) final { if (!InThreadEnv() && (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { EnterThreadEnv(); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); ExitThreadEnv(); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const 
ProducerConsumer *op) final { + void VisitStmt_(const ProducerConsumer* op) final { EnterProducerConsumer(op); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); ExitProducerConsumer(); } - void Visit_(const Load *op) final { + void VisitExpr_(const Load* op) final { HandleLoadStoreToVariable(op->buffer_var); - return IRVisitor::Visit_(op); + return StmtExprVisitor::VisitExpr_(op); } - void Visit_(const Store *op) final { + void VisitStmt_(const Store* op) final { HandleLoadStoreToVariable(op->buffer_var); - return IRVisitor::Visit_(op); + return StmtExprVisitor::VisitStmt_(op); } //@} diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index e587f385734f..f5cd6c183737 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -22,23 +22,23 @@ */ #include #include -#include +#include namespace tvm { namespace schedule { using namespace ir; -class ElemWiseDetector : public ir::IRVisitor { +class ElemWiseDetector : public ir::ExprVisitor { public: explicit ElemWiseDetector(Array axis) : axis_(axis) {} - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (!is_elem_wise_) return; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { Array axis = op->args; if (axis_.size() != axis.size()) { is_elem_wise_ = false; @@ -51,7 +51,7 @@ class ElemWiseDetector : public ir::IRVisitor { return; } } - IRVisitor::Visit_(op); + ExprVisitor::VisitExpr_(op); } bool is_elem_wise_{true}; @@ -64,7 +64,7 @@ class ElemWiseDetector : public ir::IRVisitor { bool IsElemWise(const Operation& op) { if (const ComputeOpNode* compute = op.as()) { ElemWiseDetector v = ElemWiseDetector(compute->axis); - for (auto& e : compute->body) v.Visit(e); + for (auto& e : compute->body) v(e); return v.is_elem_wise_; } return false; diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 
d4baded91f7c..7cf5cff0aff7 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -21,7 +21,6 @@ * \file bound.cc * \brief The bound inference logic. */ -#include #include #include #include diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index c3024a71977f..a5ed43601024 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -22,7 +22,7 @@ * \brief Utilities to get information about schedule graph. */ #include -#include +#include #include #include #include diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 70a73abc4698..9aef563fbefc 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include #include #include #include "message_passing.h" @@ -42,24 +42,24 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { } // The replacer of cache. -class VarReplacer : public ir::IRMutator { +class VarReplacer : public ir::StmtExprMutator { public: explicit VarReplacer( const std::unordered_map& vsub) : vsub_(vsub) {} - Expr Mutate_(const Variable* op, const Expr& e) { + Expr VisitExpr_(const Variable* op) final { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; - return e; + return GetRef(op); } ir::CommReducer MutateCommReducer(ir::CommReducer combiner) { // Replace free variables in combiner auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) { - return this->Mutate(e); + return this->VisitExpr(e); }); auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) { - return this->Mutate(e); + return this->VisitExpr(e); }); if (combiner->identity_element.same_as(new_identity) && @@ -71,8 +71,8 @@ class VarReplacer : public ir::IRMutator { } } - Expr Mutate_(const ir::Reduce* op, const Expr& e) { - Expr new_e = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const ir::Reduce* op) final { + Expr new_e = 
StmtExprMutator::VisitExpr_(op); const ir::Reduce* new_reduce = new_e.as(); ir::CommReducer new_combiner = MutateCommReducer(op->combiner); if (op->combiner.same_as(new_combiner)) { @@ -316,9 +316,9 @@ Array CacheWriteWithReLayout(Schedule sch, Array body_list; const ir::Reduce* first_reduce = nullptr; for (auto cbody : compute->body) { - body = VarReplacer(vsub).Mutate(cbody); + body = VarReplacer(vsub)(cbody); body = InjectPredicate(predicates, body); - body = VarReplacer(vsub2newvar).Mutate(body); + body = VarReplacer(vsub2newvar)(body); // Reduce nodes in ONE computeOp must be the same except value_index // This is right only if the original body ensures Reduce nodes are the same if (body->IsInstance()) { @@ -404,8 +404,8 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, for (Region old_region : tensor_op->input_regions) { Region region; for (Range r : old_region) { - Expr min = VarReplacer(vsub2newvar).Mutate(r->min); - Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); + Expr min = VarReplacer(vsub2newvar)(r->min); + Expr extent = VarReplacer(vsub2newvar)(r->extent); region.push_back(Range::make_by_min_extent(min, extent)); } new_regions.push_back(region); @@ -413,7 +413,7 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Array new_scalar_inputs; for (Expr old_input : tensor_op->scalar_inputs) { - new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input)); + new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } Operation cache_op = TensorComputeOpNode::make( @@ -786,9 +786,9 @@ Array Schedule::rfactor(const Tensor& tensor, } VarReplacer replacer(vsub); Array new_source = ir::UpdateArray(reduce->source, - [&replacer] (const Expr& e) { return replacer.Mutate(e); }); + [&replacer] (const Expr& e) { return replacer(e); }); - Expr new_pred = replacer.Mutate(predicate); + Expr new_pred = replacer(predicate); std::vector body; for (size_t idx = 0; idx < reduce->source.size(); ++idx) { diff --git a/src/schedule/schedule_lang.cc 
b/src/schedule/schedule_lang.cc index ec73c67bedff..91d3726f0bab 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include #include "graph.h" diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 0103410e6132..b177d6f8d22f 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -21,9 +21,8 @@ * \file schedule_ops.cc */ #include -#include #include -#include +#include #include #include #include @@ -71,7 +70,7 @@ Stmt MakePipeline(const Stage& s, } // inject the operator's realization on the stmt. -class InjectAttach : public IRMutator { +class InjectAttach : public StmtMutator { public: InjectAttach(const Stage& stage, const Stage& attach_spec, @@ -80,9 +79,9 @@ class InjectAttach : public IRMutator { : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} - Stmt Mutate(Stmt stmt) final { - CHECK(stmt.defined()); - stmt = IRMutator::Mutate(stmt); + Stmt VisitStmt(const Stmt& input_stmt) final { + CHECK(input_stmt.defined()); + auto stmt = StmtMutator::VisitStmt(input_stmt); const AttrStmt* op = stmt.as(); if (op != nullptr && op->attr_key == attr::loop_scope) { @@ -115,7 +114,7 @@ class InjectAttach : public IRMutator { }; // inject the operator's realization on the stmt. 
-class InjectScanStep : public IRMutator { +class InjectScanStep : public StmtMutator { public: InjectScanStep(const Stage& stage, const Operation& scan_op, @@ -125,9 +124,9 @@ class InjectScanStep : public IRMutator { : stage_(stage), scan_op_(scan_op), dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} - Stmt Mutate(Stmt stmt) final { - CHECK(stmt.defined()); - stmt = IRMutator::Mutate(stmt); + Stmt VisitStmt(const Stmt& input_stmt) final { + CHECK(input_stmt.defined()); + auto stmt = StmtMutator::VisitStmt(input_stmt); // update const AttrStmt* op = stmt.as(); if (op != nullptr && @@ -161,12 +160,12 @@ class InjectScanStep : public IRMutator { // Postprocessing of schedule op // Replace the init and update's expression by scan's buffer. -class SchedulePostProc : public IRMutator { +class SchedulePostProc : public StmtExprMutator { public: - Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final { + Stmt VisitStmt_(const ProducerConsumer* op) final { auto it = replace_op_.find(op->func.get()); if (it != replace_op_.end()) { - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (it->second.defined()) { return ProducerConsumer::make( it->second, op->is_producer, body); @@ -174,36 +173,36 @@ class SchedulePostProc : public IRMutator { return body; } } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { if (!HasSideEffect(op->value)) { - var_value_[op->var.get()] = Mutate(op->value); - return this->Mutate(op->body); + var_value_[op->var.get()] = this->VisitExpr(op->value); + return this->VisitStmt(op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::loop_scope || op->attr_key == 
attr::scan_init_scope) { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::scan_update_scope) { const ScanOpNode* scan = op->node.as(); CHECK(scan); var_value_[scan->scan_axis->var.get()] = op->value; - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::thread_extent) { // delete duplicated thread extent attr auto it = thread_extent_scope_.find(op->node.get()); if (it != thread_extent_scope_.end()) { CHECK(is_zero(ir::Simplify(it->second - op->value))); - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else { thread_extent_scope_[op->node.get()] = op->value; - Stmt ret = IRMutator::Mutate_(op, s); + Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extent_scope_.erase(op->node.get()); return ret; } @@ -214,9 +213,9 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { Stmt ret = AttrStmt::make( it->second, op->attr_key, op->value, op->body); - return this->Mutate(ret); + return this->VisitStmt(ret); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } } else if (op->attr_key == ir::attr::buffer_bind_scope) { @@ -227,9 +226,9 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { return AttrStmt::make( Array{tuple[0], it->second.output(tensor->value_index)}, - op->attr_key, op->value, Mutate(op->body)); + op->attr_key, op->value, this->VisitStmt(op->body)); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } } else if (op->attr_key == ir::attr::buffer_dim_align) { @@ -239,16 +238,16 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { return AttrStmt::make( it->second.output(tensor->value_index), - op->attr_key, op->value, Mutate(op->body)); + op->attr_key, op->value, this->VisitStmt(op->body)); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } } - return IRMutator::Mutate_(op, s); + return 
StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Realize* op, const Stmt& s) final { + Stmt VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { @@ -256,29 +255,29 @@ class SchedulePostProc : public IRMutator { Stmt ret = Realize::make( it->second->op, it->second->value_index, op->dtype, op->bounds, op->condition, op->body); - return this->Mutate(ret); + return this->VisitStmt(ret); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const Provide* op, const Stmt& s) final { + Stmt VisitStmt_(const Provide* op) final { TensorKey key{op->func, op->value_index}; auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; Stmt ret = Provide::make( dst->op, dst->value_index, op->value, op->args); - return this->Mutate(ret); + return this->VisitStmt(ret); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->call_type == Call::Halide) { TensorKey key{op->func, op->value_index}; auto it = replace_buffer_.find(key); @@ -287,18 +286,18 @@ class SchedulePostProc : public IRMutator { Expr ret = Call::make( op->dtype, dst->op->name, op->args, op->call_type, dst->op, dst->value_index); - return this->Mutate(ret); + return this->VisitExpr(ret); } } - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = var_value_.find(op); if (it != var_value_.end()) { return it->second; } else { - return e; + return GetRef(op); } } @@ -392,14 +391,14 @@ Stmt ScheduleOps( if (scan_init.count(s->op)) { CHECK(body.defined()); InjectScanStep 
mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); - body = mu.Mutate(body); + body = mu(std::move(body)); CHECK(mu.found_attach) << "did not find attachment point for scan.init"; } else if (attach_spec->attach_type == kScanUpdate) { // Handle scan update CHECK(body.defined()); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop); - body = mu.Mutate(body); + body = mu(std::move(body)); CHECK(mu.found_attach) << "did not find attachment point for scan.update"; } else if (attach_spec->attach_type == kInlinedAlready) { @@ -411,7 +410,7 @@ Stmt ScheduleOps( CHECK_EQ(attach_spec->attach_type, kScope); CHECK(body.defined()); InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop); - body = mutator.Mutate(body); + body = mutator(std::move(body)); CHECK(mutator.found_attach) << "did not find attachment point for " << s << " in " << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar @@ -421,7 +420,7 @@ Stmt ScheduleOps( } SchedulePostProc post_proc; post_proc.Init(sch); - return post_proc.Mutate(body); + return post_proc(std::move(body)); } } // namespace schedule diff --git a/tests/cpp/ir_visitor_test.cc b/tests/cpp/ir_visitor_test.cc index 4282a0026ee6..1f34b2549d0d 100644 --- a/tests/cpp/ir_visitor_test.cc +++ b/tests/cpp/ir_visitor_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include TEST(IRVisitor, CountVar) {