From 1a62c23c573e41c84834e9edf4046c874221068a Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 09:56:03 -0800 Subject: [PATCH 01/45] CombineContextCall --- src/pass/combine_context_call.cc | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/pass/combine_context_call.cc b/src/pass/combine_context_call.cc index e050fee98e67..5f35d2c6c3b6 100644 --- a/src/pass/combine_context_call.cc +++ b/src/pass/combine_context_call.cc @@ -23,7 +23,7 @@ * \file combine_context_call.cc */ #include -#include +#include #include #include @@ -32,7 +32,7 @@ namespace ir { // Calculate the statistics of packed function. // These information are needed during codegen. -class ContextCallCombiner final : public IRMutator { +class ContextCallCombiner final : public StmtExprMutator { public: struct CompareExpr { bool operator()(const Expr& lhs, const Expr& rhs) const { @@ -40,7 +40,7 @@ class ContextCallCombiner final : public IRMutator { } }; - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_thread_context)) { CHECK_EQ(op->args.size(), 1U); Expr ctx = op->args[0]; @@ -60,39 +60,39 @@ class ContextCallCombiner final : public IRMutator { return std::move(ctx_var); } } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) { // Map of comparison expression to variable std::map temp; std::swap(temp, ctx_map_); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); std::swap(temp, ctx_map_); return BuildContext(temp, stmt); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { if (op->for_type == ForType::Parallel) { // Map of comparison expression to variable std::map temp; std::swap(temp, ctx_map_); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); std::swap(temp, ctx_map_); return BuildContext(temp, stmt); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } Stmt Combine(Stmt stmt) { - return BuildContext(ctx_map_, this->Mutate(stmt)); + return BuildContext(ctx_map_, this->VisitStmt(stmt)); } private: From a33ee1c70348d796c7f14639f2284b37e9e0fc1e Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 10:02:36 -0800 Subject: [PATCH 02/45] Migrate BoundChecker --- src/pass/bound_checker.cc | 44 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/pass/bound_checker.cc b/src/pass/bound_checker.cc index 648302e9740a..d3898a2ecac6 100644 --- a/src/pass/bound_checker.cc +++ b/src/pass/bound_checker.cc @@ -23,9 +23,8 @@ // Instrument checkers for out of the bounds access. #include -#include #include -#include +#include #include #include #include @@ -33,48 +32,48 @@ namespace tvm { namespace ir { -class BoundCollector : public IRVisitor { +class BoundCollector : public StmtVisitor { public: BoundCollector() {} - void Visit_(const AttrStmt *op) { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == ir::attr::buffer_bound) { if (const Variable *key = op->node.as()) { mem_to_shape[key] = op->value; } } - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. std::unordered_map mem_to_shape; }; -class BoundChecker : public IRMutator { +class BoundChecker : public StmtExprMutator { public: explicit BoundChecker( const std::unordered_map &mem_to_shape) : mem_to_shape_(mem_to_shape) {} - Stmt Mutate_(const Allocate *op, const Stmt &s) final { + Stmt VisitStmt_(const Allocate* op) final { // If the shape was updated we should update the hashtable. if (UpdateIsNeeded(op->buffer_var)) { Update(op->buffer_var, op->extents, op->dtype); } - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Expr Mutate_(const Call *op, const Expr &ex) final { + Expr VisitExpr_(const Call* op) final { if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { unsafe_rewritten_ = true; } - return IRMutator::Mutate_(op, ex); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Store *op, const Stmt &s) final { + Stmt VisitStmt_(const Store* op) final { store_scope_bound_collector_.clear(); process_store_ = true; unsafe_rewritten_ = false; - IRMutator::Mutate_(op, s); + StmtExprMutator::VisitStmt_(op); process_store_ = false; if (CanInstrument(op->index, op->buffer_var)) { Collect(op->index, op->buffer_var); @@ -92,23 +91,24 @@ class BoundChecker : public IRMutator { return body; } } - return s; + return GetRef(op); } - Expr Mutate_(const Load *op, const Expr &ex) final { + Expr VisitExpr_(const Load* op) final { if (CanInstrument(op->index, op->buffer_var)) { Collect(op->index, op->buffer_var); } - return IRMutator::Mutate_(op, ex); + return StmtExprMutator::VisitExpr_(op); } private: - bool UpdateIsNeeded(const VarExpr &buffer_var) const { + bool UpdateIsNeeded(const VarExpr& buffer_var) const { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const VarExpr &buffer_var, const Array &new_shape, - const DataType &type) { + void Update(const VarExpr& buffer_var, + const Array& new_shape, + const DataType& type) { // Sanity check at first. if (!new_shape.size()) { return; @@ -132,7 +132,7 @@ class BoundChecker : public IRMutator { mem_to_shape_[buffer_var.get()] = shape; } - bool IndexIsValid(const Expr &index) const { + bool IndexIsValid(const Expr& index) const { if (!index.defined()) { return false; } @@ -146,7 +146,7 @@ class BoundChecker : public IRMutator { return true; } - bool CanInstrument(const Expr &index, const VarExpr &buffer_var) const { + bool CanInstrument(const Expr& index, const VarExpr& buffer_var) const { return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && !unsafe_rewritten_; } @@ -206,8 +206,8 @@ class BoundChecker : public IRMutator { Stmt InstrumentBoundCheckers(Stmt stmt) { BoundCollector bound_collector; // At first walk recursively and collect bound attributes. - bound_collector.Visit(stmt); - return BoundChecker(bound_collector.mem_to_shape).Mutate(stmt); + bound_collector(stmt); + return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt)); } } // namespace ir } // namespace tvm From 3bbc29b50b918e91e868ade69ee4ada3fa9c4980 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 10:12:57 -0800 Subject: [PATCH 03/45] Migrate CoprocSync --- src/pass/coproc_sync.cc | 65 ++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index a5b3285f7fa9..aed94a58da0d 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -22,8 +22,7 @@ */ #include #include -#include -#include +#include #include #include #include "ir_util.h" @@ -33,25 +32,25 @@ namespace tvm { namespace ir { // Visitor to find touched set by co-processor scope. -class CoProcTouchedBuffer : public IRVisitor { +class CoProcTouchedBuffer : public StmtExprVisitor { public: - void Visit_(const Load* op) final { + void VisitExpr_(const Load* op) final { if (in_scope_) { touched_[op->buffer_var.get()].coproc = true; } else { touched_[op->buffer_var.get()].normal = true; } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } - void Visit_(const Store* op) final { + void VisitStmt_(const Store* op) final { if (in_scope_) { touched_[op->buffer_var.get()].coproc = true; } else { touched_[op->buffer_var.get()].normal = true; } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { const Variable* buffer = op->args[1].as(); if (in_scope_) { @@ -60,17 +59,17 @@ class CoProcTouchedBuffer : public IRVisitor { touched_[buffer].normal = true; } } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::coproc_scope && !in_scope_) { in_scope_ = true; IterVar iv = Downcast(op->node); coproc_.insert(iv); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); in_scope_ = false; } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } @@ -218,12 +217,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { write_barrier_name_ = coproc_name + ".coproc_write_barrier"; } - void PlanReadBarrier(Stmt stmt) { + void PlanReadBarrier(const Stmt& stmt) { read_barrier_ = true; this->Visit(stmt); PlanReadBarrier(scope_.back(), nullptr); } - void PlanWriteBarrier(Stmt stmt) { + void PlanWriteBarrier(const Stmt& stmt) { read_barrier_ = false; this->Visit(stmt); PlanWriteBarrier(scope_.back(), nullptr); @@ -356,7 +355,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { }; -class CoProcInstDepDetector : public IRVisitor { +class CoProcInstDepDetector : public StmtVisitor { public: explicit CoProcInstDepDetector( const IterVar& coproc_axis, @@ -366,15 +365,15 @@ class CoProcInstDepDetector : public IRVisitor { sync_pop_name_ = coproc_name + ".coproc_dep_pop"; } - void Plan(Stmt stmt) { - this->Visit(stmt); + void Plan(const Stmt& stmt) { + this->VisitStmt(stmt); if (last_state_.node != nullptr) { MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); } } - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) { const IntImm* ctx_id = op->value.as(); @@ -385,15 +384,15 @@ class CoProcInstDepDetector : public IRVisitor { curr_state_.exit_ctx.insert(ctx_id->value); UpdateState(); } else { - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } } - void Visit_(const For* op) final { + void VisitStmt_(const For* op) final { SyncState temp_first, temp_last; std::swap(first_state_, temp_first); std::swap(last_state_, temp_last); - this->Visit(op->body); + this->VisitStmt(op->body); curr_state_.clear(); if (last_state_.node != nullptr) { curr_state_.node = op; @@ -412,13 +411,13 @@ class CoProcInstDepDetector : public IRVisitor { } } - void Visit_(const IfThenElse* op) final { + void VisitStmt_(const IfThenElse* op) final { SyncState temp_first, temp_last, curr_state; std::swap(first_state_, temp_first); std::swap(last_state_, temp_last); { // then stmt - this->Visit(op->then_case); + this->VisitStmt(op->then_case); if (last_state_.node != nullptr) { curr_state.node = op; MatchFixEnterPop(first_state_); @@ -434,7 +433,7 @@ class CoProcInstDepDetector : public IRVisitor { last_state_.clear(); } if (op->else_case.defined()) { - this->Visit(op->else_case); + this->VisitStmt(op->else_case); if (last_state_.node != nullptr) { curr_state.node = op; MatchFixEnterPop(first_state_); @@ -606,11 +605,11 @@ class CoProcInstDepDetector : public IRVisitor { }; -class CoProcSyncInserter : public IRMutator { +class CoProcSyncInserter : public StmtMutator { public: Stmt Insert(Stmt stmt) { CoProcTouchedBuffer visitor; - visitor.Visit(stmt); + visitor(stmt); if (visitor.coproc_.size() == 0) return stmt; std::unordered_set touched; @@ -652,10 +651,10 @@ class CoProcSyncInserter : public IRMutator { auto& vec = insert_after_[kv.first]; vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } - return Mutate(stmt); + return operator()(std::move(stmt)); } - Stmt Mutate(Stmt stmt) final { + Stmt VisitStmt(const Stmt& stmt) final { Stmt before, after; auto it = insert_before_.find(stmt.get()); if (it != insert_before_.end()) { @@ -666,14 +665,14 @@ class CoProcSyncInserter : public IRMutator { if (it != insert_after_.end()) { after = MergeSeq(it->second); } - stmt = IRMutator::Mutate(stmt); + Stmt new_stmt = StmtMutator::VisitStmt(stmt); if (before.defined()) { - stmt = Block::make(before, stmt); + new_stmt = Block::make(before, new_stmt); } if (after.defined()) { - stmt = Block::make(stmt, after); + new_stmt = Block::make(new_stmt, after); } - return stmt; + return new_stmt; } private: @@ -685,7 +684,7 @@ class CoProcSyncInserter : public IRMutator { Stmt CoProcSync(Stmt stmt) { - return CoProcSyncInserter().Insert(stmt); + return CoProcSyncInserter().Insert(std::move(stmt)); } } // namespace ir From 9dcaaa96b6d74018952fc777fde2980733436643 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 14:04:01 -0800 Subject: [PATCH 04/45] Migrate detect_device --- src/pass/detect_device.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pass/detect_device.cc b/src/pass/detect_device.cc index cd7c979171a6..202f2556a9fd 100644 --- a/src/pass/detect_device.cc +++ b/src/pass/detect_device.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,7 +22,6 @@ */ #include -#include #include "../pass/ir_util.h" namespace tvm { From 335c2aa40b11a1cc8a2ea26183d033881e6531fc Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 14:25:57 -0800 Subject: [PATCH 05/45] Migrate loop_partition --- include/tvm/ir_functor_ext.h | 29 +++++++++ include/tvm/ir_mutator.h | 21 ------ include/tvm/ir_visitor.h | 9 --- src/api/api_pass.cc | 3 +- src/op/compute_op.cc | 12 ++-- src/op/hybrid_op.cc | 2 +- src/pass/hoist_if_then_else.cc | 4 +- src/pass/loop_partition.cc | 113 ++++++++++++++++----------------- src/schedule/graph.cc | 2 +- 9 files changed, 93 insertions(+), 102 deletions(-) diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index f7a0f2a3b61c..4d821c2c4236 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -546,6 +546,35 @@ class StmtExprMutator : } }; +/*! + * \brief recursively visit the ir in post DFS order node, and transform it + * + * \param node The ir to be transformed. + * \param preorder The function called in before recursive mutation + * If preorder returns None, then the transform will proceed to recursive call. + * If preorder returns a not None Stmt/Expr, the transformer will simply return it and + * won't do further recursion. + * \param postorder The function called after recursive mutation. + * The recursive mutation result is passed to postorder for further mutation. + * \param only_enable List of StringImm. + * If it is empty, all IRNode will call preorder/postorder + * If it is not empty, preorder/postorder will only be called + * when the IRNode's type key is in the list. + */ +TVM_DLL Stmt IRTransform(Stmt node, + const runtime::PackedFunc& preorder, + const runtime::PackedFunc& postorder, + const Array& only_enable = {}); + +/*! + * \brief recursively visit the ir in post DFS order node, apply fvisit + * Each node is guaranteed to be visited only once. + * \param node The ir to be visited. + * \param fvisit The visitor function to be applied. + */ +TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); + + } // namespace ir } // namespace tvm #endif // TVM_IR_FUNCTOR_EXT_H_ diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 702cea3ce8fd..14769586a959 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -122,27 +122,6 @@ class TVM_DLL IRMutator { virtual Expr Mutate_(const StringImm* op, const Expr& e); virtual Expr Mutate_(const Shuffle* op, const Expr& e); }; - - -/*! - * \brief recursively visit the ir in post DFS order node, and transform it - * - * \param node The ir to be transformed. - * \param preorder The function called in before recursive mutation - * If preorder returns None, then the transform will proceed to recursive call. - * If preorder returns a not None Stmt/Expr, the transformer will simply return it and - * won't do further recursion. - * \param postorder The function called after recursive mutation. - * The recursive mutation result is passed to postorder for further mutation. - * \param only_enable List of StringImm. - * If it is empty, all IRNode will call preorder/postorder - * If it is not empty, preorder/postorder will only be called - * when the IRNode's type key is in the list. - */ -Stmt IRTransform(Stmt node, - const runtime::PackedFunc& preorder, - const runtime::PackedFunc& postorder, - const Array& only_enable = {}); } // namespace ir } // namespace tvm #endif // TVM_IR_MUTATOR_H_ diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index cffcdcbdf5b8..e6bd3c6f344d 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -145,15 +145,6 @@ class TVM_DLL IRVisitor { virtual void Visit_(const FloatImm* op); virtual void Visit_(const StringImm* op); }; - -/*! - * \brief recursively visit the ir in post DFS order node, apply fvisit - * Each node is guaranteed to be visited only once. - * \param node The ir to be visited. - * \param fvisit The visitor function to be applied. - */ -TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); - } // namespace ir } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 339b25a51894..f1d97f49783d 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -25,8 +25,7 @@ #include #include #include -#include -#include +#include #include namespace tvm { diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index c0cae269ffc3..85459b4d723d 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -538,7 +538,7 @@ namespace { * must be Reduce as well; and their inputs should have the * same attribute except value_index. */ -class ComputeVerifier final : protected ir::IRVisitor { +class ComputeVerifier final : protected ir::ExprVisitor { public: /// Special member functions //@{ @@ -567,20 +567,20 @@ class ComputeVerifier final : protected ir::IRVisitor { } level_ = 0; - ir::IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } } protected: /// Visitor implementation //@{ - void Visit(const ObjectRef& n) final { + void VisitExpr(const Expr& n) final { ++level_; - ir::IRVisitor::Visit(n); + ExprVisitor::VisitExpr(n); --level_; } - void Visit_(const ir::Reduce* op) final { + void VisitExpr_(const ir::Reduce* op) final { // Check for non top level reductions CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index 061929a31ef1..fd64b93a032a 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/pass/hoist_if_then_else.cc b/src/pass/hoist_if_then_else.cc index e3ffcc4f15f3..7a5011358622 100644 --- a/src/pass/hoist_if_then_else.cc +++ b/src/pass/hoist_if_then_else.cc @@ -21,9 +21,7 @@ * \file hoist_if_then_else.cc */ #include -#include -#include -#include +#include #include #include #include diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index e68387f1baad..11cf57490450 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -21,8 +21,7 @@ * \file loop_partition.cc */ #include -#include -#include +#include #include #include #include @@ -50,7 +49,6 @@ struct PartitionKeyHash { // condition cond is proven to have value cond_value (true or false) in interval. using Partition = std::unordered_map; - bool ExprUseVars(Expr expr, const std::unordered_set& vars) { bool success = false; PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { @@ -68,28 +66,28 @@ bool ExprUseVars(Expr expr, const std::unordered_set& vars) { // Rule: // - the range should not be const // - there exist a condition expression in the scope that use the var -class CandidateSelector final : public IRVisitor { +class CandidateSelector final : public StmtExprVisitor { public: using VarIsUsed = bool; explicit CandidateSelector(bool split_const_loop) : split_const_loop_(split_const_loop) {} - void Visit_(const For* op) { + void VisitStmt_(const For* op) final { // partition const loop when sets split_const_loop_ if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) { const Variable* var = op->loop_var.get(); record_.insert({var, false}); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); if (record_.at(var) && !no_split_) { candidates.insert(op); } record_.erase(var); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const AttrStmt* op) { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { const IterVarNode *iv = op->node.as(); CHECK(iv); @@ -97,7 +95,7 @@ class CandidateSelector final : public IRVisitor { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) { record_.insert({var.get(), false}); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { candidates.insert(op); } @@ -105,34 +103,34 @@ class CandidateSelector final : public IRVisitor { return; } } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Block* op) { + void VisitStmt_(const Block* op) final { bool temp = no_split_; - this->Visit(op->first); + this->VisitStmt(op->first); // erase the no split state of first when visit rest. std::swap(temp, no_split_); - this->Visit(op->rest); + this->VisitStmt(op->rest); // restore the no split flag. no_split_ = no_split_ || temp; } - void Visit_(const Call* op) { + void VisitExpr_(const Call* op) final { if (op->is_intrinsic(Call::likely)) { in_likely_ = true; - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); in_likely_ = false; } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) { // no split if the body contains allreduce. no_split_ = true; return; } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } } - void Visit_(const Variable* op) { + void VisitExpr_(const Variable* op) final { if (in_likely_ && record_.count(op)) { record_.at(op) = true; } @@ -150,7 +148,7 @@ class CandidateSelector final : public IRVisitor { // Populate partitions data structure, i.e., for a specific variable, // find an interval in which each condition // (currently, "likely" conditions) has fixed true or false value -class PartitionFinder : public IRVisitor { +class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(VarExpr current_var, const std::unordered_map& hint_map, @@ -164,18 +162,18 @@ class PartitionFinder : public IRVisitor { } } - void Visit_(const For* op) { + void VisitStmt_(const For* op) final { if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; const Variable* var = op->loop_var.get(); hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)}); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); relax_map_.erase(var); hint_map_.erase(var); } - void Visit_(const AttrStmt* op) { + void VisitStmt_(const AttrStmt* op) final { // handle thread_axis if (op->attr_key == attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); @@ -184,15 +182,15 @@ class PartitionFinder : public IRVisitor { IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); hint_map_.insert({var, dom}); relax_map_.insert({var, dom}); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); relax_map_.erase(var); hint_map_.erase(var); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const Call* op) { + void VisitExpr_(const Call* op) final { if (op->is_intrinsic(Call::likely)) { Expr cond = op->args[0]; if (ExprUseVars(cond, @@ -217,7 +215,7 @@ class PartitionFinder : public IRVisitor { } } } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } } @@ -255,17 +253,16 @@ class PartitionFinder : public IRVisitor { }; // Replace the set of conditions given by ps with cond_value (true or false) -class ConditionEliminator : public IRMutator { +class ConditionEliminator : public StmtExprMutator { public: explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) : ps_(ps), cond_value_(cond_value) {} - using IRMutator::Mutate; - Expr Mutate(Expr e) final { + Expr VisitExpr(const Expr& e) final { if (ps_.find(e.get()) != ps_.end()) { - return Mutate(cond_value_ ? const_true() : const_false()); + return VisitExpr(cond_value_ ? const_true() : const_false()); } - return IRMutator::Mutate(e); + return StmtExprMutator::VisitExpr(e); } private: @@ -275,26 +272,26 @@ class ConditionEliminator : public IRMutator { // Insert the partition branch at the innermost thread scope -class ThreadPartitionInserter : public IRMutator { +class ThreadPartitionInserter : public StmtMutator { public: explicit ThreadPartitionInserter(const std::unordered_set& ps, Expr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { innermost_thread_scope_ = true; - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtMutator::VisitStmt_(op); // add branch code inside the innermost thread scope if (innermost_thread_scope_) { - Stmt simplified_body = ConditionEliminator(ps_).Mutate(op->body); + Stmt simplified_body = ConditionEliminator(ps_)(op->body); Stmt body = IfThenElse::make(cond_, simplified_body, op->body); - Expr value = this->Mutate(op->value); + Expr value = this->VisitExpr(op->value); stmt = AttrStmt::make(op->node, op->attr_key, value, body); } innermost_thread_scope_ = false; return stmt; } else { - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } } @@ -306,19 +303,19 @@ class ThreadPartitionInserter : public IRMutator { // Try to partition range of iteration variables in order to remove (some) // likely conditions -class LoopPartitioner : public IRMutator { +class LoopPartitioner : public StmtMutator { public: explicit LoopPartitioner(bool split_const_loop) : selector(CandidateSelector(split_const_loop)) {} - Stmt VisitAndMutate(const Stmt& stmt) { - selector.Visit(stmt); - return Mutate(stmt); + Stmt VisitAndMutate(Stmt stmt) { + selector(stmt); + return operator()(std::move(stmt)); } - Stmt Mutate_(const For* op, const Stmt& stmt) { + Stmt VisitStmt_(const For* op) final { if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, stmt, op->loop_var, + Stmt s = TryPartition(op, GetRef(op), op->loop_var, op->min, op->min + op->extent - 1, op->body, false); if (s.defined()) return s; } @@ -327,21 +324,21 @@ class LoopPartitioner : public IRMutator { // normal loop variable can be put into hint map. hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); - Stmt res = IRMutator::Mutate_(op, stmt); + Stmt res = StmtMutator::VisitStmt_(op); hint_map_.erase(op->loop_var.get()); return res; } - Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key != attr::thread_extent) { - return IRMutator::Mutate_(op, stmt); + return StmtMutator::VisitStmt_(op); } const IterVarNode *iv = op->node.as(); CHECK(iv); Var var = iv->var; if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, stmt, var, 0, op->value - 1, op->body, true); + Stmt s = TryPartition(op, GetRef(op), var, 0, op->value - 1, op->body, true); if (s.defined()) return s; } @@ -352,12 +349,12 @@ class LoopPartitioner : public IRMutator { // threadIdx should be put into relax map, in case of divergence. relax_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); - res = IRMutator::Mutate_(op, stmt); + res = StmtMutator::VisitStmt_(op); relax_map_.erase(var.get()); } else { hint_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); - res = IRMutator::Mutate_(op, stmt); + res = StmtMutator::VisitStmt_(op); hint_map_.erase(var.get()); } return res; @@ -473,7 +470,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, hint_map_.insert({var.get(), IntSet::interval(min, max)}); PartitionFinder finder(var, hint_map_, relax_map_); - finder.Visit(body); + finder(body); hint_map_.erase(var.get()); if (finder.partitions.empty()) return Stmt(); @@ -564,7 +561,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Stmt mid_stmt; if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { // [body_begin, post_doubt_begin) - Stmt simplified_body = ConditionEliminator(cond_set, cond_value).Mutate(body); + Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body); @@ -586,7 +583,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Expr cond = const_true(); if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); - s = ThreadPartitionInserter(cond_set, cond).Mutate(stmt); + s = ThreadPartitionInserter(cond_set, cond)(stmt); } s = ConvertSSA(s); return s; @@ -604,23 +601,21 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, Expr extent, Stmt body) } } -class RemoveLikelyTags : public IRMutator { +class RemoveLikelyTags : public StmtExprMutator { public: - using IRMutator::Mutate; - - Expr Mutate_(const Call *op, const Expr& e) { + Expr VisitExpr_(const Call *op) final { if (op->is_intrinsic(Call::likely)) { CHECK_EQ(op->args.size(), 1); - return IRMutator::Mutate(op->args[0]); + return StmtExprMutator::VisitExpr(op->args[0]); } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } }; Stmt LoopPartition(Stmt stmt, bool split_const_loop) { - stmt = LoopPartitioner(split_const_loop).VisitAndMutate(stmt); - stmt = RemoveLikelyTags().Mutate(stmt); + stmt = LoopPartitioner(split_const_loop).VisitAndMutate(std::move(stmt)); + stmt = RemoveLikelyTags()(std::move(stmt)); return stmt; } diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index c3024a71977f..a5ed43601024 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -22,7 +22,7 @@ * \brief Utilities to get information about schedule graph. */ #include -#include +#include #include #include #include From 5b58263596dc4f113ac5c23cefa7ebc194e71198 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 14:31:04 -0800 Subject: [PATCH 06/45] Migrate infer_fragement --- src/pass/infer_fragment.cc | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc index 13f9ebade9b1..951e0083cdcd 100644 --- a/src/pass/infer_fragment.cc +++ b/src/pass/infer_fragment.cc @@ -23,8 +23,7 @@ */ #include #include -#include -#include +#include #include #include #include "ir_util.h" @@ -35,7 +34,7 @@ namespace tvm { namespace ir { // Get fragment information from tensor intrinsics -class FragmentGetter : public IRVisitor { +class FragmentGetter : public StmtExprVisitor { public: // fragment metadata struct FragmentInfo { @@ -48,8 +47,8 @@ class FragmentGetter : public IRVisitor { : m(_m), n(_n), k(_k), layout(_layout) {} }; - void Visit_(const Call* op) final { - IRVisitor::Visit_(op); + void VisitExpr_(const Call* op) final { + StmtExprVisitor::VisitExpr_(op); if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { @@ -116,13 +115,13 @@ class FragmentGetter : public IRVisitor { } // Get memory scope - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { const Variable* buffer = op->node.as(); CHECK(buffer); scopes[buffer] = op->value.as()->value; } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } // Memory scope for allocations @@ -132,11 +131,12 @@ class FragmentGetter : public IRVisitor { }; // Check shape of fragment making sure it is a valid shape for tvm_mma_sync -class FragmentChecker : public IRVisitor { +class FragmentChecker : public StmtExprVisitor { public: explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { + StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { CHECK_EQ(op->args.size(), 8U); @@ -170,12 +170,12 @@ class FragmentChecker : public IRVisitor { }; // Store the metadata into attributes -class InferFragmenter : public IRMutator { +class InferFragmenter : public StmtMutator { public: explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} - Stmt Mutate_(const Allocate* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Allocate* op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); const Variable* buffer = op->buffer_var.get(); if (fragment_getter.fragments.count(buffer)) { // Add attribute to fragments allocation @@ -206,9 +206,10 @@ class InferFragmenter : public IRMutator { Stmt InferFragment(Stmt stmt) { FragmentGetter getter; - getter.Visit(stmt); - FragmentChecker(getter).Visit(stmt); - stmt = InferFragmenter(getter).Mutate(stmt); + getter(stmt); + FragmentChecker checker(getter); + checker(stmt); + stmt = InferFragmenter(getter)(std::move(stmt)); return stmt; } From 6e853f364585a62d19138b604c0ea2388ed7d281 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 14:32:35 -0800 Subject: [PATCH 07/45] Migrate inject_copy_intrin --- src/pass/inject_copy_intrin.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 7b7c5df48236..d1ba19b9fb05 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include "../arithmetic/pattern_match.h" @@ -32,7 +32,7 @@ namespace ir { using runtime::PackedFunc; -class CopyIntrinInjector : public IRMutator { +class CopyIntrinInjector : public StmtMutator { public: CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto) @@ -40,7 +40,7 @@ class CopyIntrinInjector : public IRMutator { flower_copy_fromto_(flower_copy_fromto) { } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as(); storage_scope_[buf] = op->value.as()->value; @@ -50,7 +50,7 @@ class CopyIntrinInjector : public IRMutator { << "Cannot match copy pattern of " << op->body; return ret; } - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } private: @@ -193,8 +193,7 @@ class CopyIntrinInjector : public IRMutator { Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, const PackedFunc& flower_copy_fromto) { - return CopyIntrinInjector(pragma_key, flower_copy_fromto) - .Mutate(stmt); + return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); } } // namespace ir From 5407c8df5d9358a24f68b9ae2ee01b10cf715c02 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 14:38:51 -0800 Subject: [PATCH 08/45] Migrate inject double buffer --- src/pass/inject_double_buffer.cc | 69 ++++++++++++++++---------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 78d3305d3e17..84b2f705e995 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -22,8 +22,7 @@ * \file inject_double_buffer.cc */ #include -#include -#include +#include #include #include "ir_util.h" #include "../arithmetic/compute_expr.h" @@ -32,18 +31,18 @@ namespace tvm { namespace ir { // Detect double buffer variables. -class DoubleBufferDetector : public IRVisitor { +class DoubleBufferDetector : public StmtExprVisitor { public: - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::double_buffer_scope) { touched_.insert(op->node.as()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const Variable* op) final { + void VisitExpr_(const Variable* op) final { if (touched_.count(op)) { touched_.erase(op); } @@ -53,55 +52,55 @@ class DoubleBufferDetector : public IRVisitor { }; -class StripDoubleBufferWrite : public IRMutator { +class StripDoubleBufferWrite : public StmtMutator { public: - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::double_buffer_write) { - return Mutate(op->body); + return VisitStmt(op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } } }; -class DoubleBufferInjector : public IRMutator { +class DoubleBufferInjector : public StmtExprMutator { public: explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {} - Stmt Inject(const Stmt& stmt) { + Stmt Inject(Stmt stmt) { DoubleBufferDetector detector; - detector.Visit(stmt); + detector(stmt); if (detector.touched_.empty()) return stmt; for (const Variable* v : detector.touched_) { dbuffer_info_[v] = StorageEntry(); } - return ConvertSSA(this->Mutate(stmt)); + return ConvertSSA(operator()(std::move(stmt))); } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as(); auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { it->second.scope = op->value.as()->value; - return Mutate(op->body); + return this->VisitStmt(op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } else if (op->attr_key == attr::double_buffer_scope) { - return MakeProducer(op, s); + return MakeProducer(op); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { it->second.stride = arith::ComputeReduce( op->extents, Expr()) * op->dtype.lanes(); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); Array new_extents{make_const(op->extents[0].dtype(), 2)}; for (Expr e : op->extents) { @@ -118,13 +117,13 @@ class DoubleBufferInjector : public IRMutator { Evaluate::make(0))); return op->body; } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { loop_nest_.push_back(op); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); auto it = loop_pre_.find(op); if (it != loop_pre_.end()) { const For* old_loop = stmt.as(); @@ -151,7 +150,7 @@ class DoubleBufferInjector : public IRMutator { MergeSeq(loop_seq)); // tail std::vector tail_seq; - Stmt tail_body = StripDoubleBufferWrite().Mutate(old_loop->body); + Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); for (int32_t i = 0; i < split_loop_; ++i) { Expr idx = tail_base + make_const(tail_base.dtype(), i); vmap[old_loop->loop_var.get()] = idx; @@ -171,8 +170,8 @@ class DoubleBufferInjector : public IRMutator { return stmt; } - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { @@ -188,8 +187,8 @@ class DoubleBufferInjector : public IRMutator { } } - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { @@ -205,20 +204,20 @@ class DoubleBufferInjector : public IRMutator { } } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { CHECK(!dbuffer_info_.count(op)); - return e; + return GetRef(op); } private: - Stmt MakeProducer(const AttrStmt* op, const Stmt& s) { + Stmt MakeProducer(const AttrStmt* op) { const VarExpr buffer = Downcast(op->node); CHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; auto it = dbuffer_info_.find(buffer.get()); if (it == dbuffer_info_.end()) { LOG(WARNING) << "Skip double buffer scope " << op->node; - return Mutate(op->body); + return this->VisitStmt(op->body); } StorageEntry& e = it->second; e.loop = loop_nest_.back(); @@ -230,7 +229,7 @@ class DoubleBufferInjector : public IRMutator { e.loop->loop_var.dtype()); e.switch_read_var = indexmod(e.loop->loop_var, two); in_double_buffer_scope_ = true; - Stmt body = Mutate(op->body); + Stmt body = this->VisitStmt(op->body); in_double_buffer_scope_ = false; std::unordered_map vmap; vmap[e.switch_write_var.get()] = zero; From eefd692dcbdde1526d6ba3cadda497186a2a59fe Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 1 Jan 2020 15:19:26 -0800 Subject: [PATCH 09/45] Migrate lower_intrin and simplify --- src/arithmetic/canonical_simplify.cc | 76 +++++++-------- src/arithmetic/ir_mutator_with_analyzer.cc | 91 ++++++++++-------- src/arithmetic/ir_mutator_with_analyzer.h | 27 +++--- src/arithmetic/rewrite_simplify.cc | 102 ++++++++++----------- src/arithmetic/rewrite_simplify.h | 52 +++++------ src/arithmetic/stmt_simplify.cc | 37 ++++---- src/pass/lower_intrin.cc | 64 +++++++------ 7 files changed, 234 insertions(+), 215 deletions(-) diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 6a19a7aeb3f2..9ad69a61ed5f 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -435,30 +435,30 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { Expr CanonicalSimplify(Expr expr) { - expr = Mutate(expr); + expr = operator()(expr); return expr; } // override the original mutate function. - Expr Mutate(Expr expr) final { - expr = IRMutator::Mutate(expr); + Expr VisitExpr(const Expr& input_expr) final { + auto expr = Rewriter::VisitExpr(input_expr); return Normalize(expr); } // Normal mutation without normalization. Expr CanonicalMutate(Expr expr) { - return IRMutator::Mutate(expr); + return Rewriter::VisitExpr(expr); } - using Rewriter::Mutate_; - Expr Mutate_(const Add* op, const Expr& self) final; - Expr Mutate_(const Sub* op, const Expr& self) final; - Expr Mutate_(const Mul* op, const Expr& self) final; - Expr Mutate_(const Div* op, const Expr& self) final; - Expr Mutate_(const Mod* op, const Expr& self) final; - Expr Mutate_(const FloorDiv* op, const Expr& self) final; - Expr Mutate_(const FloorMod* op, const Expr& self) final; - Expr Mutate_(const Reduce* op, const Expr& self) final; + using Rewriter::VisitExpr_; + Expr VisitExpr_(const Add* op) final; + Expr VisitExpr_(const Sub* op) final; + Expr VisitExpr_(const Mul* op) final; + Expr VisitExpr_(const Div* op) final; + Expr VisitExpr_(const Mod* op) final; + Expr VisitExpr_(const FloorDiv* op) final; + Expr VisitExpr_(const FloorMod* op) final; + Expr VisitExpr_(const Reduce* op) final; private: /*! @@ -567,9 +567,9 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { }; Expr CanonicalSimplifier::Impl:: -Mutate_(const Add* op, const Expr& self) { +VisitExpr_(const Add* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -593,9 +593,9 @@ Mutate_(const Add* op, const Expr& self) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Sub* op, const Expr& self) { +VisitExpr_(const Sub* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -620,9 +620,9 @@ Mutate_(const Sub* op, const Expr& self) { Expr CanonicalSimplifier::Impl:: -Mutate_(const Mul* op, const Expr& self) { +VisitExpr_(const Mul* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -652,7 +652,7 @@ Mutate_(const Mul* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return Mul::make(a, b); } @@ -727,9 +727,9 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Div* op, const Expr& self) { +VisitExpr_(const Div* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } Expr a = this->CanonicalMutate(op->a); @@ -781,16 +781,16 @@ Mutate_(const Div* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return Div::make(a, b); } } Expr CanonicalSimplifier::Impl:: -Mutate_(const FloorDiv* op, const Expr& self) { +VisitExpr_(const FloorDiv* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } Expr a = this->CanonicalMutate(op->a); Expr b = this->CanonicalMutate(op->b); @@ -837,7 +837,7 @@ Mutate_(const FloorDiv* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return FloorDiv::make(a, b); } @@ -866,7 +866,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { // Do a recursive call to simplify the mod with the new factor. if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { - auto updated = ToSplitExpr(Mutate(ModImpl( + auto updated = ToSplitExpr(this->VisitExpr(ModImpl( lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); // re-apply the lower_factor if (lhs->lower_factor != 1) { @@ -894,9 +894,9 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Mod* op, const Expr& self) { +VisitExpr_(const Mod* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -957,16 +957,16 @@ Mutate_(const Mod* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return Mod::make(a, b); } } Expr CanonicalSimplifier::Impl:: -Mutate_(const FloorMod* op, const Expr& self) { +VisitExpr_(const FloorMod* op) { if (!IsIndexType(op->dtype)) { - return Rewriter::Mutate_(op, self); + return Rewriter::VisitExpr_(op); } // normalize Expr a = this->CanonicalMutate(op->a); @@ -1017,7 +1017,7 @@ Mutate_(const FloorMod* op, const Expr& self) { a = Normalize(a); b = Normalize(b); if (op->a.same_as(a) && op->b.same_as(b)) { - return self; + return GetRef(op); } else { return FloorMod::make(a, b); } @@ -1029,7 +1029,7 @@ SimplifyReduceCombiner(const Reduce* op) { // First simplify the results Array simplified_result; for (const auto& res : op->combiner->result) { - Expr new_res = Mutate(res); + Expr new_res = this->VisitExpr(res); simplified_result.push_back(new_res); } @@ -1078,7 +1078,7 @@ SimplifyReduceCombiner(const Reduce* op) { if (used[i]) { // We simplify the result and identity, but not the source new_result.push_back(simplified_result[i]); - new_identity.push_back(Mutate(op->combiner->identity_element[i])); + new_identity.push_back(this->VisitExpr(op->combiner->identity_element[i])); new_lhs.push_back(op->combiner->lhs[i]); new_rhs.push_back(op->combiner->rhs[i]); new_source.push_back(op->source[i]); @@ -1095,9 +1095,9 @@ SimplifyReduceCombiner(const Reduce* op) { } Expr CanonicalSimplifier::Impl:: -Mutate_(const Reduce* op, const Expr& self) { +VisitExpr_(const Reduce* op) { // Recursively call simplification when necessary. - Expr ret = RewriteSimplifier::Impl::Mutate_(op, self); + Expr ret = RewriteSimplifier::Impl::VisitExpr_(op); op = ret.as(); // already been simplified by const reduction axis removal if (op == nullptr) return ret; @@ -1106,7 +1106,7 @@ Mutate_(const Reduce* op, const Expr& self) { // assumption we would have to perform a single iteration of the loop, i.e. use // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. - return Mutate( + return this->VisitExpr( Select::make(op->condition, op->source[op->value_index], op->combiner->identity_element[op->value_index])); diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index 0d4b8f26b18b..bfce2c26fbe3 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -30,41 +30,44 @@ namespace arith { using namespace ir; Stmt IRMutatorWithAnalyzer:: -Mutate_(const For* op, const Stmt& s) { +VisitStmt_(const For* op) { analyzer_->Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); - return IRMutator::Mutate_(op, s); + Range::make_by_min_extent(op->min, op->extent)); + return StmtExprMutator::VisitStmt_(op); } Stmt IRMutatorWithAnalyzer:: -Mutate_(const LetStmt* op, const Stmt& s) { - Expr value = this->Mutate(op->value); +VisitStmt_(const LetStmt* op) { + Expr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } // We keep the let-binding here // as sub-class may or maynot choose to replace it. - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return s; + return GetRef(op); } else { - return LetStmt::make(op->var, value, body); + auto n = this->CopyOnWrite(op); + n->value = std::move(value); + n->body = std::move(body); + return Stmt(n); } } Stmt IRMutatorWithAnalyzer:: -Mutate_(const IfThenElse* op, const Stmt& s) { - Expr condition = this->Mutate(op->condition); +VisitStmt_(const IfThenElse* op) { + Expr condition = this->VisitExpr(op->condition); Stmt then_case, else_case; { With ctx(analyzer_, condition); - then_case = this->Mutate(op->then_case); + then_case = this->VisitStmt(op->then_case); } if (op->else_case.defined()) { With ctx(analyzer_, analyzer_->rewrite_simplify(Not::make(condition))); - else_case = this->Mutate(op->else_case); + else_case = this->VisitStmt(op->else_case); } if (is_one(condition)) return then_case; if (is_zero(condition)) { @@ -77,57 +80,65 @@ Mutate_(const IfThenElse* op, const Stmt& s) { if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return s; + return GetRef(op); } else { - return IfThenElse::make(condition, then_case, else_case); + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->then_case = std::move(then_case); + n->else_case = std::move(else_case); + return Stmt(n); } } Stmt IRMutatorWithAnalyzer:: -Mutate_(const AttrStmt* op, const Stmt& s) { +VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } Stmt IRMutatorWithAnalyzer:: -Mutate_(const AssertStmt* op, const Stmt& s) { - Expr condition = this->Mutate(op->condition); - Expr message = this->Mutate(op->message); +VisitStmt_(const AssertStmt* op) { + Expr condition = this->VisitExpr(op->condition); + Expr message = this->VisitExpr(op->message); With ctx(analyzer_, condition); - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { - return s; + return GetRef(op); } else { - return AssertStmt::make(condition, message, body); + auto n = this->CopyOnWrite(op); + n->condition = std::move(condition); + n->message = std::move(message); + n->body = std::move(body); + return Stmt(n); } } Expr IRMutatorWithAnalyzer:: -Mutate_(const Call* op, const Expr& self) { +VisitExpr_(const Call* op) { // add condition context to if_then_else if (op->is_intrinsic(ir::intrinsic::tvm_if_then_else)) { - Expr cond = Mutate(op->args[0]); + Expr cond = this->VisitExpr(op->args[0]); Expr true_value, false_value; { With constraint(analyzer_, cond); - true_value = Mutate(op->args[1]); + true_value = this->VisitExpr(op->args[1]); } { With constraint(analyzer_, analyzer_->rewrite_simplify(Not::make(cond))); - false_value = Mutate(op->args[2]); + false_value = this->VisitExpr(op->args[2]); } if (is_zero(cond)) { return false_value; @@ -138,45 +149,45 @@ Mutate_(const Call* op, const Expr& self) { if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { - return self; + return GetRef(op); } else { return Call::make(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } } - return IRMutator::Mutate_(op, self); + return StmtExprMutator::VisitExpr_(op); } Expr IRMutatorWithAnalyzer:: -Mutate_(const Let* op, const Expr& self) { - Expr value = this->Mutate(op->value); +VisitExpr_(const Let* op) { + Expr value = this->VisitExpr(op->value); if (!ir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); } // We keep the let-binding here // as sub-class may or maynot choose to replace it. - Expr body = this->Mutate(op->body); + Expr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return self; + return GetRef(op); } else { return Let::make(op->var, value, body); } } Expr IRMutatorWithAnalyzer:: -Mutate_(const Select* op, const Expr& self) { - Expr cond = Mutate(op->condition); +VisitExpr_(const Select* op) { + Expr cond = this->VisitExpr(op->condition); Expr true_value, false_value; { With constraint(analyzer_, cond); - true_value = Mutate(op->true_value); + true_value = VisitExpr(op->true_value); } { With constraint(analyzer_, analyzer_->rewrite_simplify(Not::make(cond))); - false_value = Mutate(op->false_value); + false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { return false_value; @@ -188,20 +199,20 @@ Mutate_(const Select* op, const Expr& self) { if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { - return self; + return GetRef(op); } else { return Select::make(cond, true_value, false_value); } } Expr IRMutatorWithAnalyzer:: -Mutate_(const Reduce* op, const Expr& self) { +VisitExpr_(const Reduce* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_->Bind(iv->var, iv->dom); } // Recursively call simplification when necessary. - return IRMutator::Mutate_(op, self); + return StmtExprMutator::VisitExpr_(op); } } // namespace arith diff --git a/src/arithmetic/ir_mutator_with_analyzer.h b/src/arithmetic/ir_mutator_with_analyzer.h index bf4118e9c698..9e3a86bb5280 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.h +++ b/src/arithmetic/ir_mutator_with_analyzer.h @@ -24,9 +24,9 @@ #ifndef TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITHMETIC_IR_MUTATOR_WITH_ANALYZER_H_ -#include +#include #include - +#include namespace tvm { namespace arith { @@ -40,23 +40,24 @@ namespace arith { * * \sa src/arithmetic/ir_mutator_with_analyzer.cc */ -class IRMutatorWithAnalyzer : public ir::IRMutator { +class IRMutatorWithAnalyzer : public ir::StmtExprMutator { public: explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} - using IRMutator::Mutate_; + using StmtExprMutator::VisitStmt_; + using StmtExprMutator::VisitExpr_; // override functions that need to populate the context information. - Stmt Mutate_(const ir::For* op, const Stmt& self) override; - Stmt Mutate_(const ir::LetStmt* op, const Stmt& self) override; - Stmt Mutate_(const ir::IfThenElse* op, const Stmt& self) override; - Stmt Mutate_(const ir::AttrStmt* op, const Stmt& self) override; - Stmt Mutate_(const ir::AssertStmt* op, const Stmt& self) override; - Expr Mutate_(const ir::Let* op, const Expr& self) override; - Expr Mutate_(const ir::Select* op, const Expr& self) override; - Expr Mutate_(const ir::Call* op, const Expr& self) override; - Expr Mutate_(const ir::Reduce* op, const Expr& self) override; + Stmt VisitStmt_(const ir::For* op) override; + Stmt VisitStmt_(const ir::LetStmt* op) override; + Stmt VisitStmt_(const ir::IfThenElse* op) override; + Stmt VisitStmt_(const ir::AttrStmt* op) override; + Stmt VisitStmt_(const ir::AssertStmt* op) override; + Expr VisitExpr_(const ir::Let* op) override; + Expr VisitExpr_(const ir::Select* op) override; + Expr VisitExpr_(const ir::Call* op) override; + Expr VisitExpr_(const ir::Reduce* op) override; protected: /*! \brief internal analyzer field. */ diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 235306cc7bf8..60cc63ff2b96 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -69,7 +69,7 @@ using namespace ir; // try to prove x equals val RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: TryCompare(const Expr& x, int64_t val) { - Expr diff = Mutate(x); + Expr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { return kEQ; @@ -117,8 +117,8 @@ Update(const Var& var, const Expr& info, bool override) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Add* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Add* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -232,8 +232,8 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const Expr& const } Expr RewriteSimplifier::Impl:: -Mutate_(const Sub* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Sub* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -431,8 +431,8 @@ Mutate_(const Sub* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Mul* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Mul* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -470,8 +470,8 @@ Mutate_(const Mul* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Div* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Div* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as
(); Expr const_res = TryConstFold
(op->a, op->b); if (const_res.defined()) return const_res; @@ -692,8 +692,8 @@ Mutate_(const Div* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Mod* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Mod* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -782,8 +782,8 @@ Mutate_(const Mod* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const FloorDiv* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const FloorDiv* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -926,8 +926,8 @@ Mutate_(const FloorDiv* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const FloorMod* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const FloorMod* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -996,8 +996,8 @@ Mutate_(const FloorMod* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Min* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Min* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1181,8 +1181,8 @@ Mutate_(const Min* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Max* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Max* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1354,8 +1354,8 @@ Mutate_(const Max* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const EQ* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const EQ* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1388,28 +1388,28 @@ Mutate_(const EQ* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const NE* op, const Expr& self) { - return Mutate(Not::make(op->a == op->b)); +VisitExpr_(const NE* op) { + return this->VisitExpr(Not::make(op->a == op->b)); } Expr RewriteSimplifier::Impl:: -Mutate_(const LE* op, const Expr& self) { - return Mutate(Not::make(op->b < op->a)); +VisitExpr_(const LE* op) { + return this->VisitExpr(Not::make(op->b < op->a)); } Expr RewriteSimplifier::Impl:: -Mutate_(const GT* op, const Expr& self) { - return Mutate(op->b < op->a); +VisitExpr_(const GT* op) { + return this->VisitExpr(op->b < op->a); } Expr RewriteSimplifier::Impl:: -Mutate_(const GE* op, const Expr& self) { - return Mutate(Not::make(op->a < op->b)); +VisitExpr_(const GE* op) { + return this->VisitExpr(Not::make(op->a < op->b)); } Expr RewriteSimplifier::Impl:: -Mutate_(const LT* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const LT* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1564,8 +1564,8 @@ Mutate_(const LT* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Not* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Not* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; @@ -1589,8 +1589,8 @@ Mutate_(const Not* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const And* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const And* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1638,8 +1638,8 @@ Mutate_(const And* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Or* op, const Expr& self) { - Expr ret = IRMutator::Mutate_(op, self); +VisitExpr_(const Or* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; @@ -1688,8 +1688,8 @@ Mutate_(const Or* op, const Expr& self) { } Expr RewriteSimplifier::Impl:: -Mutate_(const Select* op, const Expr& self) { - Expr ret = IRMutatorWithAnalyzer::Mutate_(op, self); +VisitExpr_(const Select* op) { + Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); UnsafeExprDetector unsafe; bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); @@ -131,7 +130,7 @@ class UnsafeSelectRewriter : public IRMutator { }; Stmt RewriteUnsafeSelect(Stmt stmt) { - return UnsafeSelectRewriter().Mutate(stmt); + return UnsafeSelectRewriter()(std::move(stmt)); } } // namespace ir From bf21e7eec5c0273e1744fdf43b09518339a5b0ad Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 09:29:36 -0800 Subject: [PATCH 23/45] Migrate skip_assert simple_passes --- src/pass/simple_passes.cc | 41 +++++++++++++++++++-------------------- src/pass/skip_assert.cc | 10 +++++----- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 1159e568f519..e9ed89304167 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -22,25 +22,24 @@ * \brief Implementation of simple passes */ #include -#include -#include +#include #include namespace tvm { namespace ir { -class IRSideEffect : public IRVisitor { +class IRSideEffect : public ExprVisitor { public: - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (has_side_effect_) return; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { if (!op->is_pure()) { has_side_effect_ = true; return; } else { - IRVisitor::Visit_(op); + ExprVisitor::VisitExpr_(op); } } @@ -49,23 +48,23 @@ class IRSideEffect : public IRVisitor { bool HasSideEffect(const Expr& e) { IRSideEffect v; - v.Visit(e); + v(e); return v.has_side_effect_; } -class IRSubstitue : public IRMutator { +class IRSubstitue : public StmtExprMutator { public: explicit IRSubstitue( const std::unordered_map& smap) : smap_(smap) { } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = smap_.find(op); if (it != smap_.end()) { return it->second; } else { - return e; + return GetRef(op); } } @@ -76,13 +75,13 @@ class IRSubstitue : public IRMutator { Stmt Substitute(Stmt stmt, const std::unordered_map& value_map) { if (value_map.size() == 0) return stmt; - return IRSubstitue(value_map).Mutate(stmt); + return IRSubstitue(value_map)(std::move(stmt)); } Expr Substitute(Expr expr, const std::unordered_map& value_map) { if (value_map.size() == 0) return expr; - return IRSubstitue(value_map).Mutate(expr); + return IRSubstitue(value_map)(std::move(expr)); } Stmt Substitute(Stmt stmt, const Map& value_map) { @@ -101,20 +100,20 @@ Expr Substitute(Expr expr, const Map& value_map) { return Substitute(expr, vmap); } -class VarTouchVisitor : public IRVisitor { +class VarTouchVisitor : public ExprVisitor { public: - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (use_var_) return; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } - void Visit_(const Variable* op) final { + void VisitExpr_(const Variable* op) final { Handle(op); } - void Visit_(const Load* op) final { + void VisitExpr_(const Load* op) final { Handle(op->buffer_var.get()); - IRVisitor::Visit_(op); + ExprVisitor::VisitExpr_(op); } virtual void Handle(const Variable* var) = 0; @@ -149,14 +148,14 @@ class ExprUseVSetVisitor : public VarTouchVisitor { bool ExprUseVar(const Expr& e, const Var& v) { ExprUseVarVisitor visitor(v.get()); - visitor.Visit(e); + visitor(e); return visitor.use_var_; } bool ExprUseVar(const Expr& e, const std::unordered_set& vset) { ExprUseVSetVisitor visitor(vset); - visitor.Visit(e); + visitor(e); return visitor.use_var_; } diff --git a/src/pass/skip_assert.cc b/src/pass/skip_assert.cc index 817416d9fd2c..47a158f1d760 100644 --- a/src/pass/skip_assert.cc +++ b/src/pass/skip_assert.cc @@ -19,22 +19,22 @@ #include #include -#include +#include namespace tvm { namespace ir { -class AssertSkipper : public IRMutator { +class AssertSkipper : public StmtMutator { public: - Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const AssertStmt* op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); return op->body; } }; Stmt SkipAssert(Stmt stmt) { - return AssertSkipper().Mutate(stmt); + return AssertSkipper()(std::move(stmt)); } LoweredFunc SkipAssert(LoweredFunc f) { From 42bf9e905f189469b6b9da883778d68280e9b50a Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 09:42:06 -0800 Subject: [PATCH 24/45] Migrate split_host_device --- src/pass/split_host_device.cc | 72 ++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index f045c271456c..2a7c75e04eb5 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -32,9 +32,9 @@ namespace tvm { namespace ir { // use/def analysis, also delete unreferenced lets -class IRUseDefAnalysis : public IRMutator { +class IRUseDefAnalysis : public StmtExprMutator { public: - Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); @@ -48,75 +48,77 @@ class IRUseDefAnalysis : public IRMutator { Expr value = op->value; if (visit_thread_extent_) { - value = this->Mutate(value); + value = this->VisitExpr(value); + } + Stmt body = this->VisitStmt(op->body); + if (value.same_as(op->value) && body.same_as(op->body)) { + return GetRef(op); } - Stmt body = this->Mutate(op->body); - if (value.same_as(op->value) && body.same_as(op->body)) return s; return AttrStmt::make(op->node, op->attr_key, value, body); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const LetStmt *op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { this->HandleDef(op->var.get()); - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { - Expr value = this->Mutate(op->value); + Expr value = this->VisitExpr(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { - return s; + return GetRef(op); } else { return LetStmt::make(op->var, value, body); } } } - Stmt Mutate_(const For *op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { this->HandleDef(op->loop_var.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Allocate *op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { this->HandleDef(op->buffer_var.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Store *op, const Stmt& s) final { + Stmt VisitStmt_(const Store* op) final { this->HandleUse(op->buffer_var); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Expr Mutate_(const Let *op, const Expr& e) final { + Expr VisitExpr_(const Let* op) final { this->HandleDef(op->var.get()); - Expr body = this->Mutate(op->body); + Expr body = this->VisitExpr(op->body); // eliminate unreferenced let if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { - Expr value = this->Mutate(op->value); + Expr value = this->VisitExpr(op->value); if (body.same_as(op->body) && value.same_as(op->value)) { - return e; + return GetRef(op); } else { return Let::make(op->var, value, body); } } } - Expr Mutate_(const Variable *op, const Expr& e) final { - this->HandleUse(e); - return IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Variable* op) final { + this->HandleUse(GetRef(op)); + return StmtExprMutator::VisitExpr_(op); } - Expr Mutate_(const Load *op, const Expr& e) final { + Expr VisitExpr_(const Load* op) final { this->HandleUse(op->buffer_var); - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } void HandleDef(const Variable* v) { @@ -154,20 +156,20 @@ class IRUseDefAnalysis : public IRMutator { std::unordered_map def_count_; }; -class HostDeviceSplitter : public IRMutator { +class HostDeviceSplitter : public StmtMutator { public: - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } - Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { - return SplitDeviceFunc(s); + return SplitDeviceFunc(GetRef(op)); } - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } Array Split(LoweredFunc f) { @@ -178,7 +180,7 @@ class HostDeviceSplitter : public IRMutator { name_ = f->name; ObjectPtr n = make_object(*f.operator->()); - n->body = this->Mutate(f->body); + n->body = operator()(f->body); n->func_type = kHostFunc; Array ret{LoweredFunc(n)}; for (LoweredFunc x : device_funcs_) { @@ -195,7 +197,7 @@ class HostDeviceSplitter : public IRMutator { // isolate the device function. IRUseDefAnalysis m; m.visit_thread_extent_ = false; - n->body = m.Mutate(body); + n->body = m(std::move(body)); n->name = os.str(); n->func_type = kDeviceFunc; n->thread_axis = m.thread_axis_; @@ -243,7 +245,7 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { for (Var arg : args) { m.use_count_[arg.get()] = 0; } - m.Mutate(stmt); + m(stmt); return m.undefined_; } From fdd5a74edbf2e6d1e9da450e711cdc954bd24973 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 10:45:57 -0800 Subject: [PATCH 25/45] Migrate ssa --- src/pass/ssa.cc | 90 ++++++++++++++++++++++++++----------------------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 37db29c58079..94f045d94561 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -24,8 +24,7 @@ * \file ssa.cc */ #include -#include -#include +#include #include #include #include @@ -34,29 +33,33 @@ namespace tvm { namespace ir { namespace { -class IRVerifySSA final : public IRVisitor { +class IRVerifySSA final : public StmtExprVisitor { public: bool is_ssa{true}; - void Visit(const ObjectRef& n) final { + void VisitExpr(const Expr& n) final { if (!is_ssa) return; - IRVisitor::Visit(n); + StmtExprVisitor::VisitExpr(n); } - void Visit_(const Let* op) final { + void VisitStmt(const Stmt& n) final { + if (!is_ssa) return; + StmtExprVisitor::VisitStmt(n); + } + void VisitExpr_(const Let* op) final { MarkDef(op->var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } - void Visit_(const LetStmt* op) final { + void VisitStmt_(const LetStmt* op) final { MarkDef(op->var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const For* op) final { + void VisitStmt_(const For* op) final { MarkDef(op->loop_var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Allocate* op) final { + void VisitStmt_(const Allocate* op) final { MarkDef(op->buffer_var.get()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } private: @@ -70,31 +73,32 @@ class IRVerifySSA final : public IRVisitor { std::unordered_map defined_; }; -class IRConvertSSA final : public IRMutator { + +class IRConvertSSA final : public StmtExprMutator { public: - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { if (scope_.count(op)) { return scope_[op].back(); } else { - return e; + return GetRef(op); } } - Expr Mutate_(const Let* op, const Expr& e) final { + Expr VisitExpr_(const Let* op) final { const VarExpr& v = op->var; if (defined_.count(v.get())) { - Expr value = IRMutator::Mutate(op->value); + Expr value = this->VisitExpr(op->value); VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Expr body = IRMutator::Mutate(op->body); + Expr body = this->VisitExpr(op->body); scope_[v.get()].pop_back(); return Let::make(new_var, value, body); } else { defined_.insert(v.get()); - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (scope_.count(op->buffer_var.get())) { return Load::make( @@ -104,8 +108,8 @@ class IRConvertSSA final : public IRMutator { return expr; } } - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(op->buffer_var.get())) { return Store::make( @@ -115,41 +119,41 @@ class IRConvertSSA final : public IRMutator { return stmt; } } - Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { const VarExpr& v = op->var; if (defined_.count(v.get())) { - Expr value = IRMutator::Mutate(op->value); + Expr value = this->VisitExpr(op->value); VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Stmt body = IRMutator::Mutate(op->body); + Stmt body = this->VisitStmt(op->body); scope_[v.get()].pop_back(); return LetStmt::make(new_var, value, body); } else { defined_.insert(v.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { const VarExpr& v = op->loop_var; if (defined_.count(v.get())) { VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); return For::make( new_var, op->min, op->extent, op->for_type, op->device_api, op->body); } else { defined_.insert(v.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { const VarExpr& v = op->buffer_var; if (defined_.count(v.get())) { VarExpr new_var = Variable::make(v.dtype(), v->name_hint); scope_[v.get()].push_back(new_var); - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); return Allocate::make( @@ -157,23 +161,23 @@ class IRConvertSSA final : public IRMutator { op->body, op->new_expr, op->free_function); } else { defined_.insert(v.get()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (const Variable* v = op->node.as()) { if (op->attr_key == attr::storage_scope) { const Allocate* alloc = op->body.as(); if (alloc && op->node.same_as(alloc->buffer_var)) { - Stmt new_alloc = Mutate(op->body); - if (new_alloc.same_as(op->body)) return s; + Stmt new_alloc = this->VisitStmt(op->body); + if (new_alloc.same_as(op->body)) return GetRef(op); alloc = new_alloc.as(); CHECK(alloc); return AttrStmt::make( alloc->buffer_var, op->attr_key, op->value, new_alloc); } } - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { return AttrStmt::make( @@ -182,7 +186,7 @@ class IRConvertSSA final : public IRMutator { return stmt; } } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } @@ -194,13 +198,13 @@ class IRConvertSSA final : public IRMutator { } // namespace bool VerifySSA(const Stmt& ir) { - IRVerifySSA v; - v.Visit(ir); - return v.is_ssa; + IRVerifySSA visitor; + visitor(ir); + return visitor.is_ssa; } Stmt ConvertSSA(Stmt stmt) { - return IRConvertSSA().Mutate(stmt); + return IRConvertSSA()(std::move(stmt)); } } // namespace ir From 2bae5173acdb45204f9bd4d5831ae2c7e506c17f Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 10:51:25 -0800 Subject: [PATCH 26/45] Migrate storage_access --- src/pass/coproc_sync.cc | 6 ++-- src/pass/storage_access.cc | 70 +++++++++++++++++++------------------- src/pass/storage_access.h | 18 +++++----- src/pass/storage_sync.cc | 2 +- 4 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index aed94a58da0d..924305199628 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -95,7 +95,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } void Plan(const Stmt& stmt) { - this->Visit(stmt); + this->VisitStmt(stmt); PlanSync(scope_.back(), nullptr, true); if (sync_.size() == 0) { sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync"); @@ -219,12 +219,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { void PlanReadBarrier(const Stmt& stmt) { read_barrier_ = true; - this->Visit(stmt); + this->VisitStmt(stmt); PlanReadBarrier(scope_.back(), nullptr); } void PlanWriteBarrier(const Stmt& stmt) { read_barrier_ = false; - this->Visit(stmt); + this->VisitStmt(stmt); PlanWriteBarrier(scope_.back(), nullptr); } diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index bf8d4e020521..0d594046e5df 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -21,7 +21,6 @@ * \file storage_access.cc */ #include -#include #include #include #include @@ -32,7 +31,7 @@ namespace tvm { namespace ir { -void StorageAccessVisitor::Visit_(const Load* op) { +void StorageAccessVisitor::VisitExpr_(const Load* op) { const Variable* buf = op->buffer_var.as(); StorageScope scope = GetScope(buf); if (Enabled(buf, scope)) { @@ -47,10 +46,10 @@ void StorageAccessVisitor::Visit_(const Load* op) { curr_stmt_.access.emplace_back(std::move(e)); } // traverse child - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } -void StorageAccessVisitor::Visit_(const Store* op) { +void StorageAccessVisitor::VisitStmt_(const Store* op) { allow_append_ = true; CHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; @@ -67,7 +66,7 @@ void StorageAccessVisitor::Visit_(const Store* op) { curr_stmt_.access.emplace_back(std::move(e)); } // traverse child - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // push to the scope scope_.back().push_back(curr_stmt_); // clear access entry. @@ -75,11 +74,11 @@ void StorageAccessVisitor::Visit_(const Store* op) { allow_append_ = false; } -void StorageAccessVisitor::Visit_(const Evaluate* op) { +void StorageAccessVisitor::VisitStmt_(const Evaluate* op) { allow_append_ = true; CHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // push to the scope if (curr_stmt_.access.size() != 0) { scope_.back().push_back(curr_stmt_); @@ -88,17 +87,17 @@ void StorageAccessVisitor::Visit_(const Evaluate* op) { allow_append_ = false; } -void StorageAccessVisitor::Visit_(const AttrStmt* op) { +void StorageAccessVisitor::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as(); storage_scope_[buf] = StorageScope::make(op->value.as()->value); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::double_buffer_write) { CHECK(double_buffer_write_ == nullptr); double_buffer_write_ = op->node.as(); scope_.push_back(std::vector()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); StmtEntry s; s.stmt = op; s.access = Summarize(std::move(scope_.back()), nullptr); @@ -115,7 +114,7 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) { } else if (op->attr_key == attr::coproc_scope) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); env_threads_.CopyOnWrite()->data.pop_back(); } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -123,23 +122,23 @@ void StorageAccessVisitor::Visit_(const AttrStmt* op) { if (!in_device_env_) { in_device_env_ = true; scope_.push_back(std::vector()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // no need to take the result as the thread barrier automatically syncs. Summarize(std::move(scope_.back()), nullptr); in_device_env_ = false; scope_.pop_back(); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } env_threads_.CopyOnWrite()->data.pop_back(); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } -void StorageAccessVisitor::Visit_(const For* op) { +void StorageAccessVisitor::VisitStmt_(const For* op) { scope_.push_back(std::vector()); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); StmtEntry s; s.stmt = op; s.access = Summarize(std::move(scope_.back()), op); @@ -161,11 +160,11 @@ void StorageAccessVisitor::Visit_(const For* op) { } } -void StorageAccessVisitor::Visit_(const IfThenElse* op) { +void StorageAccessVisitor::VisitStmt_(const IfThenElse* op) { ++condition_counter_; - this->Visit(op->condition); + this->VisitExpr(op->condition); scope_.push_back(std::vector()); - this->Visit(op->then_case); + this->VisitStmt(op->then_case); StmtEntry s; s.stmt = op; s.access = Summarize(std::move(scope_.back()), nullptr); @@ -180,10 +179,10 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) { --condition_counter_; } -void StorageAccessVisitor::Visit_(const Call* op) { +void StorageAccessVisitor::VisitExpr_(const Call* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load *l = op->args[0].as(); - IRVisitor::Visit_(l); + StmtExprVisitor::VisitExpr_(l); } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -211,7 +210,7 @@ void StorageAccessVisitor::Visit_(const Call* op) { curr_stmt_.access.emplace_back(e); } } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; @@ -224,7 +223,7 @@ void StorageAccessVisitor::Visit_(const Call* op) { curr_stmt_.access.emplace_back(std::move(e)); } } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } } @@ -236,11 +235,12 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const { return it->second; } -class StorageAccessInfoLower : public IRMutator { + +class StorageAccessInfoLower : public StmtExprMutator { public: - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { // Lower allocate to device allocate when needed. - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); // For special memory, remove allocate, or use head expr auto it = storage_info_.find(op->buffer_var.get()); @@ -259,7 +259,7 @@ class StorageAccessInfoLower : public IRMutator { return stmt; } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { const Variable* buf = op->node.as(); StorageScope scope = StorageScope::make(op->value.as()->value); @@ -270,26 +270,26 @@ class StorageAccessInfoLower : public IRMutator { CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); } storage_info_[buf] = e; - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Call* op, const Expr &e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { - return MakeAccessPtr(op, e); + return MakeAccessPtr(op); } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } private: // tvm_access_ptr - Expr MakeAccessPtr(const Call* op, const Expr& e) { + Expr MakeAccessPtr(const Call* op) { // Specially handle the buffer packed intrinsic - Expr expr = IRMutator::Mutate_(op, e); + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -337,7 +337,7 @@ class StorageAccessInfoLower : public IRMutator { }; Stmt LowerStorageAccessInfo(Stmt stmt) { - return StorageAccessInfoLower().Mutate(stmt); + return StorageAccessInfoLower()(std::move(stmt)); } LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { diff --git a/src/pass/storage_access.h b/src/pass/storage_access.h index 302ca929581d..12bf9f369de3 100644 --- a/src/pass/storage_access.h +++ b/src/pass/storage_access.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include "../runtime/thread_storage_scope.h" @@ -40,7 +40,7 @@ using runtime::StorageRank; /*! * \brief Base class of storage access analysis */ -class StorageAccessVisitor : public IRVisitor { +class StorageAccessVisitor : public StmtExprVisitor { public: /*! \brief Storage access type */ enum AccessType { @@ -76,13 +76,13 @@ class StorageAccessVisitor : public IRVisitor { std::vector access; }; // override visitor pattern - void Visit_(const Load* op) final; - void Visit_(const Store* op) final; - void Visit_(const Evaluate* op) final; - void Visit_(const AttrStmt* op) final; - void Visit_(const For* op) final; - void Visit_(const IfThenElse* op) final; - void Visit_(const Call* op) final; + void VisitExpr_(const Load* op) final; + void VisitStmt_(const Store* op) final; + void VisitStmt_(const Evaluate* op) final; + void VisitStmt_(const AttrStmt* op) final; + void VisitStmt_(const For* op) final; + void VisitStmt_(const IfThenElse* op) final; + void VisitExpr_(const Call* op) final; protected: StorageAccessVisitor() { diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 0f8bef8383f2..dd6d2a1b7ee0 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -363,7 +363,7 @@ class ThreadSyncInserter : public IRMutator { Stmt ThreadSync(Stmt stmt, std::string storage_scope) { StorageScope sync_scope = StorageScope::make(storage_scope); ThreadSyncPlanner planner(sync_scope); - planner.Visit(stmt); + planner(stmt); return ThreadSyncInserter(sync_scope, planner.syncs_inserted_).Mutate(stmt); } From 0397c7612911c2b3db2af293933b45d3f76ea663 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:00:51 -0800 Subject: [PATCH 27/45] Migrate storage_rewrite --- src/pass/storage_rewrite.cc | 181 ++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 88 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 01c6f983d692..c820c477e128 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -24,8 +24,7 @@ */ #include #include -#include -#include +#include #include #include #include @@ -54,7 +53,7 @@ using runtime::StorageScope; // The storage need to be kept alive between allocate and last access. // The free point is only inserted at the same scope of allocate. // -class LinearAccessPatternFinder final : public IRVisitor { +class LinearAccessPatternFinder final : public StmtExprVisitor { public: /*! \brief record the touch hist of statment. */ struct StmtEntry { @@ -78,7 +77,7 @@ class LinearAccessPatternFinder final : public IRVisitor { const Allocate* alloc{nullptr}; }; - void Visit_(const Allocate* op) final { + void VisitStmt_(const Allocate* op) final { size_t level = scope_.size(); const Variable* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); @@ -86,12 +85,12 @@ class LinearAccessPatternFinder final : public IRVisitor { CHECK(it->second.alloc == nullptr); it->second.alloc = op; it->second.level = level; - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Store* op) final { + void VisitStmt_(const Store* op) final { scope_.push_back(StmtEntry()); // visit subexpr - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // Add write access. const Variable* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); @@ -106,10 +105,10 @@ class LinearAccessPatternFinder final : public IRVisitor { linear_seq_.push_back(e); } } - void Visit_(const Evaluate* op) final { + void VisitStmt_(const Evaluate* op) final { scope_.push_back(StmtEntry()); // visit subexpr - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); StmtEntry e = scope_.back(); scope_.pop_back(); if (e.touched.size() != 0) { @@ -117,9 +116,9 @@ class LinearAccessPatternFinder final : public IRVisitor { linear_seq_.push_back(e); } } - void Visit_(const Load* op) final { + void VisitExpr_(const Load* op) final { // Add write access. - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); const Variable* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { @@ -128,15 +127,15 @@ class LinearAccessPatternFinder final : public IRVisitor { scope_[it->second.level].touched.push_back(buf); } } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load* l = op->args[0].as(); - this->Visit(l->index); + this->VisitExpr(l->index); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } } - void Visit_(const Variable* buf) final { + void VisitExpr_(const Variable* buf) final { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { @@ -153,7 +152,7 @@ class LinearAccessPatternFinder final : public IRVisitor { int64_t begin_index = static_cast(linear_seq_.size()); // before scope. linear_seq_.push_back(e); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); // after scope. e.touched = std::move(scope_.back().touched); scope_.pop_back(); @@ -165,7 +164,7 @@ class LinearAccessPatternFinder final : public IRVisitor { CHECK_NE(end_index, 0U); linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { // Only record the outer most thread extent. if (op->attr_key == attr::thread_extent && !in_thread_env_) { in_thread_env_ = true; @@ -179,20 +178,20 @@ class LinearAccessPatternFinder final : public IRVisitor { const Variable* buf = op->node.as(); alloc_info_[buf].storage_scope = StorageScope::make(op->value.as()->value); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const IfThenElse* op) final { + void VisitStmt_(const IfThenElse* op) final { VisitNewScope(op); } - void Visit_(const For* op) final { + void VisitStmt_(const For* op) final { VisitNewScope(op); } - void Visit_(const AssertStmt* op) final { + void VisitStmt_(const AssertStmt* op) final { VisitNewScope(op); } @@ -234,7 +233,7 @@ class LinearAccessPatternFinder final : public IRVisitor { // // The code after inplace transformation is no longer idempotent. // -class InplaceOpVerifier : public IRVisitor { +class InplaceOpVerifier : public StmtExprVisitor { public: bool Check(const Object* stmt, const Variable* dst, @@ -243,58 +242,62 @@ class InplaceOpVerifier : public IRVisitor { src_ = src; result_ = true; if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else if (stmt->IsInstance()) { - Visit_(static_cast(stmt)); + VisitStmt_(static_cast(stmt)); } else { return false; } return result_; } - using IRVisitor::Visit_; + using StmtExprVisitor::VisitStmt_; - void Visit(const ObjectRef& e) final { + void VisitStmt(const Stmt& n) final { if (!result_) return; - IRVisitor::Visit(e); + StmtExprVisitor::VisitStmt(n); + } + void VisitExpr(const Expr& n) final { + if (!result_) return; + StmtExprVisitor::VisitExpr(n); } - void Visit_(const Variable* op) final { + void VisitExpr_(const Variable* op) final { // assume all opaque access is unsafe if (op == dst_ || op == src_) { result_ = false; return; } } - void Visit_(const Store* op) final { + void VisitStmt_(const Store* op) final { ++mem_nest_; - this->Visit(op->index); + this->VisitExpr(op->index); --mem_nest_; if (op->buffer_var.get() == dst_) { store_ = op; - this->Visit(op->value); - this->Visit(op->predicate); + this->VisitExpr(op->value); + this->VisitExpr(op->predicate); store_ = nullptr; } else { - this->Visit(op->value); - this->Visit(op->predicate); + this->VisitExpr(op->value); + this->VisitExpr(op->predicate); } } - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { // always reject extern code if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) { result_ = false; return; } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } - void Visit_(const Load* op) final { + void VisitExpr_(const Load* op) final { const Variable* buf = op->buffer_var.get(); // cannot read from dst_ (no reduction) if (buf == dst_) { @@ -312,7 +315,7 @@ class InplaceOpVerifier : public IRVisitor { } } ++mem_nest_; - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); --mem_nest_; } @@ -332,7 +335,7 @@ class InplaceOpVerifier : public IRVisitor { }; // Planner to plan and rewrite memory allocation. -class StoragePlanRewriter : public IRMutator { +class StoragePlanRewriter : public StmtExprMutator { public: using StmtEntry = LinearAccessPatternFinder::StmtEntry; using AllocEntry = LinearAccessPatternFinder::AllocEntry; @@ -341,12 +344,12 @@ class StoragePlanRewriter : public IRMutator { detect_inplace_ = detect_inplace; // plan the rewrite LinearAccessPatternFinder finder; - finder.Visit(stmt); + finder(stmt); this->LivenessAnalysis(finder.linear_seq_); this->PlanMemory(finder.linear_seq_, finder.alloc_info_); this->PrepareNewAlloc(); // start rewrite - stmt = this->Mutate(stmt); + stmt = operator()(std::move(stmt)); if (attach_map_.count(nullptr)) { std::vector nest; for (StorageEntry* e : attach_map_.at(nullptr)) { @@ -363,8 +366,8 @@ class StoragePlanRewriter : public IRMutator { } return stmt; } - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return stmt; @@ -373,8 +376,8 @@ class StoragePlanRewriter : public IRMutator { RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); } - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return expr; @@ -383,7 +386,7 @@ class StoragePlanRewriter : public IRMutator { RemapIndex(op->dtype, op->index, it->second), op->predicate); } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = alloc_map_.find(op); if (it != alloc_map_.end()) { if (it->second->bits_offset != 0) { @@ -391,79 +394,81 @@ class StoragePlanRewriter : public IRMutator { } return it->second->alloc_var; } else { - return e; + return GetRef(op); } } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const Variable* buffer = op->args[1].as(); auto it = alloc_map_.find(buffer); - if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e); - const StorageEntry* se = it->second; - Expr offset = Mutate(op->args[2]); - Expr extent = Mutate(op->args[3]); - uint64_t elem_bits = dtype.bits() * dtype.lanes(); - CHECK_EQ(se->bits_offset % elem_bits, 0U); - if (se->bits_offset != 0) { - offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; - } - return Call::make( - op->dtype, op->name, - {op->args[0], se->alloc_var, offset, extent, op->args[4]}, - op->call_type); + if (it == alloc_map_.end()) { + return StmtExprMutator::VisitExpr_(op); + } + const StorageEntry* se = it->second; + Expr offset = this->VisitExpr(op->args[2]); + Expr extent = this->VisitExpr(op->args[3]); + uint64_t elem_bits = dtype.bits() * dtype.lanes(); + CHECK_EQ(se->bits_offset % elem_bits, 0U); + if (se->bits_offset != 0) { + offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; + } + return Call::make( + op->dtype, op->name, + {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + op->call_type); } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); return AttrStmt::make( op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } else if (op->attr_key == attr::volatile_scope) { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = alloc_map_.find(op->node.as()); if (it == alloc_map_.end()) return stmt; return AttrStmt::make( it->second->alloc_var, op->attr_key, op->value, op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { CHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); return For::make( op->loop_var, op->min, op->extent, op->for_type, op->device_api, MakeAttach(svec, op->body)); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const Allocate* op, const Stmt& s) final { - return this->Mutate(op->body); + Stmt VisitStmt_(const Allocate* op) final { + return this->VisitStmt(op->body); } private: @@ -929,28 +934,28 @@ class StoragePlanRewriter : public IRMutator { // Turn alloc into vector alloc // if all its access is the same vector type. -class VectorAllocRewriter : public IRMutator { +class VectorAllocRewriter : public StmtExprMutator { public: - Expr Mutate_(const Load* op, const Expr& e) final { + Expr VisitExpr_(const Load* op) final { UpdateTypeMap(op->buffer_var.get(), op->dtype); - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt VisitStmt_(const Store* op) final { UpdateTypeMap(op->buffer_var.get(), op->value.dtype()); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { DataType dtype = op->args[0].dtype(); const Variable* buffer = op->args[1].as(); UpdateTypeMap(buffer, dtype); } - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Allocate* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Allocate* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); const auto& tvec = acc_map_[op->buffer_var.get()]; @@ -989,7 +994,7 @@ class VectorAllocRewriter : public IRMutator { LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { auto n = make_object(*f.operator->()); VectorAllocRewriter rewriter; - n->body = rewriter.Mutate(n->body); + n->body = rewriter(n->body); for (Var arg : f->args) { if (arg.dtype().is_handle()) { const auto& tvec = rewriter.acc_map_[arg.get()]; @@ -1010,8 +1015,8 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { } Stmt StorageRewrite(Stmt stmt) { - stmt = StoragePlanRewriter().Rewrite(stmt, true); - return VectorAllocRewriter().Mutate(stmt); + stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); + return VectorAllocRewriter()(std::move(stmt)); } } // namespace ir } // namespace tvm From d6b476ba2c4a51017688eb92d029492e12143f5a Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:11:26 -0800 Subject: [PATCH 28/45] Migrate tensor_core --- src/pass/tensor_core.cc | 109 +++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 57 deletions(-) diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index 8dcc0e49e119..a3890cde773e 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -24,8 +24,7 @@ #include #include #include -#include -#include +#include #include #include #include @@ -73,7 +72,7 @@ Expr unpack_type_cast(const Expr &input, const DataType &target_type) { // MMAMatcher matches C = Cast(A)*Cast(B)+C, // where A & B are fp16/int8 local buffers, // and C is fp32/int32 local buffer. -class MMAMatcher: public IRVisitor { +class MMAMatcher: public StmtVisitor { public: explicit MMAMatcher(Map extern_buffer) { for (auto kv : extern_buffer) { @@ -84,22 +83,21 @@ class MMAMatcher: public IRVisitor { buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; } } - using IRVisitor::Visit_; - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::pragma_tensor_core) { tensor_core_on_ = true; - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } else if (op->attr_key == attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; - Visit(op->body); + this->VisitStmt(op->body); } else { - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } } - void Visit_(const Provide* op) final { - IRVisitor::Visit_(op); + void VisitStmt_(const Provide* op) final { + StmtVisitor::VisitStmt_(op); auto it = buf_map_.find(TensorKey{op->func, op->value_index}); if (it == buf_map_.end()) { return; @@ -113,19 +111,19 @@ class MMAMatcher: public IRVisitor { } } - void Visit_(const Realize* op) final { + void VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; if (buf_map_.count(key)) { if (!buf_map_.at(key).external) { return; } - Visit(op->body); + this->VisitStmt(op->body); } else { BufferInfo bi; bi.name = key.GetName(); bi.dtype = op->dtype; buf_map_[key] = bi; - Visit(op->body); + this->VisitStmt(op->body); buf_map_[key].released = true; } } @@ -236,12 +234,11 @@ class MMAMatcher: public IRVisitor { // BodyVisitor visits the body stmt of original ComputeOp // to get the access indices of input matrices, // if it is recognized as matrix multiply. -class BodyVisitor : public IRVisitor { +class BodyVisitor : public StmtExprVisitor { public: BodyVisitor() {} - using IRVisitor::Visit_; - void Visit_(const Reduce* op) final { + void VisitExpr_(const Reduce* op) final { auto* comm_add = op->combiner->result[0].as(); if (comm_add == nullptr || op->combiner->result.size() > 1) { return; @@ -254,12 +251,12 @@ class BodyVisitor : public IRVisitor { } tensorcore_candidate_ = true; - IRVisitor::Visit(source); + StmtExprVisitor::VisitExpr(source); } } - void Visit_(const Call* op) final { - IRVisitor::Visit_(op); + void VisitExpr_(const Call* op) final { + StmtExprVisitor::VisitExpr_(op); args_.insert(std::make_pair(op->name, op->args)); } @@ -298,7 +295,7 @@ class ScheduleAnalyser { BodyVisitor body_visitor; for (Expr expr : compute->body) { - body_visitor.Visit(expr); + body_visitor(expr); } if (!body_visitor.tensorcore_candidate_) { continue; @@ -370,12 +367,11 @@ class ScheduleAnalyser { // IndexVisitor visits access index of fragment // to record variable for loop scaling -class IndexVisitor : public IRVisitor { +class IndexVisitor : public StmtExprVisitor { public: IndexVisitor() {} - using IRVisitor::Visit_; - void Visit_(const Variable* op) final { + void VisitExpr_(const Variable* op) final { loop_scaling_.insert(std::make_pair(op, scaling_factor_)); } @@ -389,7 +385,7 @@ class IndexVisitor : public IRVisitor { // BufferAnalyser gets buffer info, // e.g. thread tile and warp tile, for TensorCore CodeGen -class BufferAnalyser : public IRVisitor { +class BufferAnalyser : public StmtExprVisitor { public: explicit BufferAnalyser(Map extern_buffer, const ScheduleAnalyser &schedule_analyser, @@ -407,9 +403,8 @@ class BufferAnalyser : public IRVisitor { buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; } } - using IRVisitor::Visit_; - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { if (const IntImm* value = op->value.as()) { thread_extent_.insert( @@ -417,10 +412,10 @@ class BufferAnalyser : public IRVisitor { op->node.as()->var->name_hint, value->value)); } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; - Visit(op->body); + this->VisitStmt(op->body); } else if (op->attr_key == attr::buffer_dim_align) { Tensor tensor = Downcast(op->node); const Call* tuple = op->value.as(); @@ -432,14 +427,14 @@ class BufferAnalyser : public IRVisitor { } vinfo[dim].align_factor = tuple->args[1].as()->value; vinfo[dim].align_offset = tuple->args[2].as()->value; - Visit(op->body); + this->VisitStmt(op->body); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const Provide* op) final { - IRVisitor::Visit_(op); + void VisitStmt_(const Provide* op) final { + StmtExprVisitor::VisitStmt_(op); TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); CHECK(it != buf_map_.end()) @@ -503,7 +498,7 @@ class BufferAnalyser : public IRVisitor { } auto index = rel_index[i]; auto simplified_index = ir::Simplify(index); - index_visitor.Visit(simplified_index); + index_visitor(simplified_index); } std::string input_name = simplify_name(bi.name); @@ -550,8 +545,8 @@ class BufferAnalyser : public IRVisitor { } } - void Visit_(const Call* op) final { - IRVisitor::Visit_(op); + void VisitExpr_(const Call* op) final { + StmtExprVisitor::VisitExpr_(op); if (op->call_type == Call::Halide) { TensorKey key{op->func, op->value_index}; auto it = buf_map_.find(key); @@ -606,16 +601,16 @@ class BufferAnalyser : public IRVisitor { } auto index = rel_index[i]; auto simplified_index = ir::Simplify(index); - index_visitor.Visit(simplified_index); + index_visitor(simplified_index); } } } - void Visit_(const Realize* op) final { + void VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; if (buf_map_.count(key)) { CHECK(buf_map_.at(key).external); - Visit(op->body); + this->VisitStmt(op->body); } else { // create a buffer entry BufferInfo bi; @@ -653,7 +648,7 @@ class BufferAnalyser : public IRVisitor { bi.shape = shape; buf_map_[key] = bi; - Visit(op->body); + this->VisitStmt(op->body); buf_map_[key].released = true; } } @@ -761,12 +756,12 @@ class BufferAnalyser : public IRVisitor { }; // ThreadIdxMutator does the thread index unification inside a warp -class ThreadIdxMutator : public IRMutator { +class ThreadIdxMutator : public StmtExprMutator { public: explicit ThreadIdxMutator(Expr warp_y): warp_y_(warp_y) {} - Expr Mutate_(const Variable* op, const Expr& olde) final { - Expr expr = IRMutator::Mutate_(op, olde); + Expr VisitExpr_(const Variable* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op != nullptr) { if (op->name_hint == "threadIdx.x") { @@ -788,7 +783,7 @@ class ThreadIdxMutator : public IRMutator { // TensorCoreIRMutator mutates the AST for TensorCore CodeGen // based on tensor core intrinsics -class TensorCoreIRMutator : public IRMutator { +class TensorCoreIRMutator : public StmtExprMutator { public: explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser, const BufferAnalyser &buffer_analyser) @@ -803,10 +798,10 @@ class TensorCoreIRMutator : public IRMutator { warp_tile_(buffer_analyser.warp_tile_), warp_threads_y_(buffer_analyser.warp_threads_y_) {} - Stmt Mutate_(const Realize* op, const Stmt& s) final { + Stmt VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; bounds_[key] = op->bounds; - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op != nullptr) { if (!frag_reg_.count(key.GetName())) { @@ -833,8 +828,8 @@ class TensorCoreIRMutator : public IRMutator { return stmt; } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const AttrStmt* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); if (op->attr_key == attr::realize_scope) { auto node = op->node.as(); if (node != nullptr) { @@ -846,7 +841,7 @@ class TensorCoreIRMutator : public IRMutator { CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; auto matrix_abc = "wmma." + it->second; - Stmt body = Mutate(op->body); + Stmt body = this->VisitStmt(op->body); return AttrStmt::make(op->node, op->attr_key, matrix_abc, @@ -856,8 +851,8 @@ class TensorCoreIRMutator : public IRMutator { return stmt; } - Stmt Mutate_(const Provide* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Provide* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); auto it = mma_sync_.find(op); if (it != mma_sync_.end()) { const auto &operands = it->second; @@ -941,7 +936,7 @@ class TensorCoreIRMutator : public IRMutator { // thread index unification inside a warp Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); - Expr mutated_value = thread_idx_mutator.Mutate(op->value); + Expr mutated_value = thread_idx_mutator(op->value); Expr src = Call::make(value->dtype, "&", {mutated_value}, @@ -991,7 +986,7 @@ class TensorCoreIRMutator : public IRMutator { // thread index unification inside a warp Expr warp_y = IntImm::make(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); - dst = thread_idx_mutator.Mutate(dst); + dst = thread_idx_mutator(dst); dst = Call::make(DataType::Handle(), "&", {dst}, @@ -1020,8 +1015,8 @@ class TensorCoreIRMutator : public IRMutator { return stmt; } - Stmt Mutate_(const For* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const For* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op != nullptr) { auto it = loop_scaling_.find(op->loop_var.get()); @@ -1177,7 +1172,7 @@ Stmt RewriteForTensorCore(Stmt stmt, } MMAMatcher mma_matcher(extern_buffer); - mma_matcher.Visit(stmt); + mma_matcher(stmt); if (!mma_matcher.Matched()) { return stmt; } @@ -1189,12 +1184,12 @@ Stmt RewriteForTensorCore(Stmt stmt, BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); - buffer_analyser.Visit(stmt); + buffer_analyser(stmt); if (!buffer_analyser.QualifiedForTensorCore()) { return stmt; } - return TensorCoreIRMutator(schedule_analyser, buffer_analyser).Mutate(stmt); + return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); } } // namespace ir From 651885045690c6e0a1056e2243512dcd41445e5b Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:17:11 -0800 Subject: [PATCH 29/45] Migrate unroll_loop --- src/pass/unroll_loop.cc | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index c56944406bd4..9fc87f3a0d6b 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -24,7 +24,7 @@ // Unrolls the loop as in Halide pipeline. #include #include -#include +#include #include #include #include @@ -33,7 +33,7 @@ namespace tvm { namespace ir { -class LoopUnroller : public IRMutator { +class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, @@ -45,12 +45,12 @@ class LoopUnroller : public IRMutator { explicit_unroll_(explicit_unroll) { } - Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { int value = 0; CHECK(arith::GetConstInt(op->value, &value)); std::swap(value, auto_max_step_); - Stmt ret = this->Mutate(op->body); + Stmt ret = this->VisitStmt(op->body); std::swap(value, auto_max_step_); return ret; } else if (op->attr_key == "pragma_unroll_explicit") { @@ -58,16 +58,16 @@ class LoopUnroller : public IRMutator { CHECK(arith::GetConstInt(op->value, &value)); bool explicit_unroll = value; std::swap(explicit_unroll, explicit_unroll_); - Stmt ret = this->Mutate(op->body); + Stmt ret = this->VisitStmt(op->body); std::swap(explicit_unroll, explicit_unroll_); return ret; } else { - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const For* op, const Stmt& s) { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const For* op) { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); int value = GetExtent(op); // condition for auto unroll @@ -110,18 +110,18 @@ class LoopUnroller : public IRMutator { } } - Stmt Mutate_(const Store* op, const Stmt& stmt) final { + Stmt VisitStmt_(const Store* op) final { ++step_count_; - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Evaluate* op, const Stmt& stmt) final { + Stmt VisitStmt_(const Evaluate* op) final { ++step_count_; - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Block* op, const Stmt& stmt) final { - Stmt first = this->Mutate(op->first); + Stmt VisitStmt_(const Block* op) final { + Stmt first = this->VisitStmt(op->first); // cleanup state int step_count = step_count_; int unroll_depth = unroll_depth_; @@ -130,13 +130,13 @@ class LoopUnroller : public IRMutator { unroll_depth_ = 0; normal_loop_depth_ = 0; // work on rest part - Stmt rest = this->Mutate(op->rest); + Stmt rest = this->VisitStmt(op->rest); step_count_ += step_count; normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); unroll_depth_ = std::max(unroll_depth_, unroll_depth); if (first.same_as(op->first) && rest.same_as(op->rest)) { - return stmt; + return GetRef(op); } else { return Block::make(first, rest); } @@ -204,7 +204,7 @@ Stmt UnrollLoop(Stmt stmt, auto_max_step, auto_max_depth, auto_max_extent, - explicit_unroll).Mutate(stmt); + explicit_unroll)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { From ef1d1e85e55ab5c8c6ecc290230e3b3ddc27083f Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:25:09 -0800 Subject: [PATCH 30/45] Migrate vectorize --- src/pass/vectorize_loop.cc | 258 ++++++++++++++++++------------------- 1 file changed, 127 insertions(+), 131 deletions(-) diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 94639f7be363..c22243cc8f93 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -23,7 +23,7 @@ // Loop vectorizer as in Halide pipeline. #include #include -#include +#include #include #include #include @@ -54,13 +54,13 @@ inline Expr BroadcastTo(Expr e, int lanes) { // // The same principle applies when using one thread to simulate multiple context. // -class VecAllocAccess : public IRMutator { +class VecAllocAccess : public StmtExprMutator { public: VecAllocAccess(const Variable* buf, Var var, int var_lanes) : buf_(buf), var_(var), var_lanes_(var_lanes) {} // Load - Expr Mutate_(const Load* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Load* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->buffer_var.get() == buf_) { return Load::make(op->dtype, op->buffer_var, @@ -71,8 +71,8 @@ class VecAllocAccess : public IRMutator { } } // Store - Stmt Mutate_(const Store* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const Store* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op->buffer_var.get() == buf_) { return Store::make(op->buffer_var, @@ -93,19 +93,16 @@ class VecAllocAccess : public IRMutator { int var_lanes_; }; -class Vectorizer : public IRMutator { +class Vectorizer : public StmtExprMutator { public: Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { ramp_ = Ramp::make(0, 1, var_lanes); } - // user mutate from parent. - using IRMutator::Mutate; - Stmt Mutate(Stmt stmt) final { + Stmt VisitStmt(const Stmt& stmt) final { CHECK(!need_scalarize_); - - Stmt ret = IRMutator::Mutate(stmt); + Stmt ret = StmtExprMutator::VisitStmt(stmt); if (need_scalarize_) { need_scalarize_ = false; return Scalarize(stmt); @@ -114,19 +111,18 @@ class Vectorizer : public IRMutator { } } - - Expr Mutate_(const Add* op, const Expr &e) final { - return AddSubVec(op, e); + Expr VisitExpr_(const Add* op) final { + return AddSubVec(op); } - Expr Mutate_(const Sub* op, const Expr &e) final { - return AddSubVec(op, e); + Expr VisitExpr_(const Sub* op) final { + return AddSubVec(op); } - Expr Mutate_(const Mul* op, const Expr &e) final { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); + Expr VisitExpr_(const Mul* op) final { + Expr a = this->VisitExpr(op->a); + Expr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return e; + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); if (lanes != 1) { @@ -143,53 +139,53 @@ class Vectorizer : public IRMutator { } return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } - return BinaryVec(op, e); + return BinaryVec(op); } - Expr Mutate_(const Div* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Div* op) final { + return BinaryVec(op); } - Expr Mutate_(const Mod* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Mod* op) final { + return BinaryVec(op); } - Expr Mutate_(const FloorDiv* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const FloorDiv* op) final { + return BinaryVec(op); } - Expr Mutate_(const FloorMod* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const FloorMod* op) final { + return BinaryVec(op); } - Expr Mutate_(const Min* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Min* op) final { + return BinaryVec(op); } - Expr Mutate_(const Max* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Max* op) final { + return BinaryVec(op); } - Expr Mutate_(const EQ* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const EQ* op) final { + return BinaryVec(op); } - Expr Mutate_(const NE* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const NE* op) final { + return BinaryVec(op); } - Expr Mutate_(const LT* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const LT* op) final { + return BinaryVec(op); } - Expr Mutate_(const LE* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const LE* op) final { + return BinaryVec(op); } - Expr Mutate_(const GT* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const GT* op) final { + return BinaryVec(op); } - Expr Mutate_(const GE* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const GE* op) final { + return BinaryVec(op); } - Expr Mutate_(const And* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const And* op) final { + return BinaryVec(op); } - Expr Mutate_(const Or* op, const Expr &e) final { - return BinaryVec(op, e); + Expr VisitExpr_(const Or* op) final { + return BinaryVec(op); } - Expr Mutate_(const Ramp* op, const Expr &e) final { - Expr base = this->Mutate(op->base); - Expr stride = this->Mutate(op->stride); + Expr VisitExpr_(const Ramp* op) final { + Expr base = this->VisitExpr(op->base); + Expr stride = this->VisitExpr(op->stride); if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { const Ramp* base_ramp = base.as(); if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { @@ -208,14 +204,14 @@ class Vectorizer : public IRMutator { } return Shuffle::make_concat(elems); } - Expr Mutate_(const Select *op, const Expr& e) final { - Expr cond = this->Mutate(op->condition); - Expr t = this->Mutate(op->true_value); - Expr f = this->Mutate(op->false_value); + Expr VisitExpr_(const Select *op) final { + Expr cond = this->VisitExpr(op->condition); + Expr t = this->VisitExpr(op->true_value); + Expr f = this->VisitExpr(op->false_value); if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { - return e; + return GetRef(op); } else { int lanes = std::max(std::max( cond.dtype().lanes(), @@ -223,37 +219,37 @@ class Vectorizer : public IRMutator { return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } - Expr Mutate_(const Cast *op, const Expr& e) final { - Expr value = this->Mutate(op->value); + Expr VisitExpr_(const Cast *op) final { + Expr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { - return e; + return GetRef(op); } else { return Cast::make(op->dtype.with_lanes(value.dtype().lanes()), value); } } // Variable - Expr Mutate_(const Variable* v, const Expr& e) final { + Expr VisitExpr_(const Variable* v) final { if (v == var_.get()) { return ramp_; } else if (lets_.count(v)) { return lets_[v]; } else { - return e; + return GetRef(v); } } // IfThenElse expr - Expr MutateIfThenElseExpr_(const Call *op, const Expr& e) { - Expr cond = this->Mutate(op->args[0]); + Expr MutateIfThenElseExpr_(const Call *op) { + Expr cond = this->VisitExpr(op->args[0]); if (cond.dtype().is_vector()) { need_scalarize_ = true; - return e; + return GetRef(op); } - Expr t = this->Mutate(op->args[1]); - Expr f = this->Mutate(op->args[2]); + Expr t = this->VisitExpr(op->args[1]); + Expr f = this->VisitExpr(op->args[2]); if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { - return e; + return GetRef(op); } else { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); @@ -264,23 +260,23 @@ class Vectorizer : public IRMutator { } } // Call - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->name == intrinsic::tvm_if_then_else) { - return MutateIfThenElseExpr_(op, e); + return MutateIfThenElseExpr_(op); } if (!op->is_vectorizable()) { // Cannot vectorize this op Array new_args; for (auto arg : op->args) { - auto new_arg = this->Mutate(arg); + auto new_arg = this->VisitExpr(arg); if (new_arg.dtype().is_vector()) { need_scalarize_ = true; - return e; + return GetRef(op); } new_args.push_back(new_arg); } if (op->args.same_as(new_args)) { - return e; + return GetRef(op); } else { return Call::make( op->dtype, op->name, new_args, op->call_type, op->func, op->value_index); @@ -290,7 +286,7 @@ class Vectorizer : public IRMutator { Array new_args = MutateArray(op->args, &lane); // normal code path. if (op->args.same_as(new_args)) { - return e; + return GetRef(op); } else { return Call::make( op->dtype.with_lanes(lane), op->name, new_args, @@ -299,11 +295,11 @@ class Vectorizer : public IRMutator { } } // Load - Expr Mutate_(const Load* op, const Expr& e) final { - Expr index = this->Mutate(op->index); - Expr pred = this->Mutate(op->predicate); + Expr VisitExpr_(const Load* op) final { + Expr index = this->VisitExpr(op->index); + Expr pred = this->VisitExpr(op->predicate); if (index.same_as(op->index) && pred.same_as(op->predicate)) { - return e; + return GetRef(op); } else { int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); return Load::make( @@ -314,42 +310,42 @@ class Vectorizer : public IRMutator { } } // Let - Expr Mutate_(const Let* op, const Expr& e) final { - Expr value = this->Mutate(op->value); + Expr VisitExpr_(const Let* op) final { + Expr value = this->VisitExpr(op->value); CHECK(!lets_.count(op->var.get())) << "not SSA"; if (value.dtype().lanes() != op->value.dtype().lanes()) { Var v(op->var->name_hint, value.dtype()); lets_[op->var.get()] = v; - return Let::make(v, value, Mutate(op->body)); + return Let::make(v, value, this->VisitExpr(op->body)); } else { - Expr body = this->Mutate(op->body); + Expr body = this->VisitExpr(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { - return e; + return GetRef(op); } else { return Let::make(op->var, value, body); } } } // Provide - Stmt Mutate_(const Provide* op, const Stmt& s) final { - Expr new_value = this->Mutate(op->value); + Stmt VisitStmt_(const Provide* op) final { + Expr new_value = this->VisitExpr(op->value); int lane = new_value.dtype().lanes(); Array new_args = MutateArray(op->args, &lane); if (op->args.same_as(new_args) && op->value.same_as(new_value)) { - return s; + return GetRef(op); } else { new_value = BroadcastTo(new_value, lane); return Provide::make(op->func, op->value_index, new_value, new_args); } } // Store - Stmt Mutate_(const Store* op, const Stmt& s) final { - Expr value = this->Mutate(op->value); - Expr index = this->Mutate(op->index); - Expr pred = this->Mutate(op->predicate); + Stmt VisitStmt_(const Store* op) final { + Expr value = this->VisitExpr(op->value); + Expr index = this->VisitExpr(op->index); + Expr pred = this->VisitExpr(op->predicate); if (value.same_as(op->value) && index.same_as(op->index)) { - return s; + return GetRef(op); } else { int lanes = std::max(value.dtype().lanes(), index.dtype().lanes()); lanes = std::max(lanes, pred.dtype().lanes()); @@ -360,20 +356,20 @@ class Vectorizer : public IRMutator { } } // For - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { if (op->for_type == ForType::Vectorized) { LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..."; } CHECK(is_zero(op->min)); CHECK(!op->extent.dtype().is_vector()); - Expr extent = Mutate(op->extent); + Expr extent = this->VisitExpr(op->extent); if (extent.dtype().is_vector()) { - return Scalarize(s); + return Scalarize(GetRef(op)); } - Stmt body = Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (extent.same_as(op->extent) && body.same_as(op->body)) { - return s; + return GetRef(op); } else { return For::make( op->loop_var, op->min, extent, @@ -381,47 +377,47 @@ class Vectorizer : public IRMutator { } } // IfThenElse - Stmt Mutate_(const IfThenElse* op, const Stmt& s) final { + Stmt VisitStmt_(const IfThenElse* op) final { CHECK(!op->condition.dtype().is_vector()); - Expr condition = this->Mutate(op->condition); + Expr condition = this->VisitExpr(op->condition); if (condition.dtype().is_vector()) { - return Scalarize(s); + return Scalarize(GetRef(op)); } - Stmt then_case = this->Mutate(op->then_case); + Stmt then_case = this->VisitStmt(op->then_case); Stmt else_case; if (op->else_case.defined()) { - else_case = this->Mutate(op->else_case); + else_case = this->VisitStmt(op->else_case); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return s; + return GetRef(op); } else { return IfThenElse::make(condition, then_case, else_case); } } // LetStmt - Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize"; - return Scalarize(s); + return Scalarize(GetRef(op)); } // Allocate - Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt VisitStmt_(const Allocate* op) final { if (op->new_expr.defined()) { LOG(WARNING) << "Cannot vectorize with new expr"; - return Scalarize(s); + return Scalarize(GetRef(op)); } - Expr condition = Mutate(op->condition); + Expr condition = this->VisitExpr(op->condition); if (condition.dtype().is_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc "; - return Scalarize(s); + return Scalarize(GetRef(op)); } Array extents; for (size_t i = 0; i < op->extents.size(); i++) { - Expr new_ext = Mutate(op->extents[i]); + Expr new_ext = this->VisitExpr(op->extents[i]); if (new_ext.dtype().is_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc "; - return Scalarize(s); + return Scalarize(GetRef(op)); } extents.push_back(new_ext); } @@ -429,8 +425,8 @@ class Vectorizer : public IRMutator { extents.push_back(var_lanes_); // rewrite access to buffer internally. Stmt body = VecAllocAccess( - op->buffer_var.get(), var_, var_lanes_).Mutate(op->body); - body = Mutate(body); + op->buffer_var.get(), var_, var_lanes_)(op->body); + body = this->VisitStmt(body); return Allocate::make( op->buffer_var, op->dtype, extents, condition, body, @@ -466,7 +462,7 @@ class Vectorizer : public IRMutator { std::vector new_arr(arr.size()); for (size_t i = 0; i < arr.size(); i++) { Expr old_elem = arr[i]; - Expr new_elem = this->Mutate(old_elem); + Expr new_elem = this->VisitExpr(old_elem); if (!new_elem.same_as(old_elem)) changed = true; new_arr[i] = new_elem; lanes = std::max(lanes, new_elem.dtype().lanes()); @@ -482,24 +478,24 @@ class Vectorizer : public IRMutator { return Array(new_arr); } template - Expr BinaryVec(const T* op, const Expr& e) { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); + Expr BinaryVec(const T* op) { + Expr a = this->VisitExpr(op->a); + Expr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return e; + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } template - Expr AddSubVec(const T* op, const Expr& e) { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); + Expr AddSubVec(const T* op) { + Expr a = this->VisitExpr(op->a); + Expr b = this->VisitExpr(op->b); if (a.same_as(op->a) && b.same_as(op->b)) { - return e; + return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); if (lanes != 1) { @@ -521,9 +517,9 @@ class Vectorizer : public IRMutator { } }; -class LoopVectorizer : public IRMutator { +class LoopVectorizer : public StmtMutator { public: - Stmt Mutate_(const For* op, const Stmt& s) final { + Stmt VisitStmt_(const For* op) final { if (op->for_type == ForType::Vectorized) { CHECK(is_zero(op->min)); int lanes = 0; @@ -531,21 +527,21 @@ class LoopVectorizer : public IRMutator { if (!succ || lanes < 1) { LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; } - return Vectorizer(op->loop_var, lanes).Mutate(op->body); + return Vectorizer(op->loop_var, lanes)(op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } } }; Stmt VectorizeLoop(Stmt stmt) { - return LoopVectorizer().Mutate(stmt); + return LoopVectorizer()(std::move(stmt)); } -class VectorizeSkipper : public IRMutator { +class VectorizeSkipper : public StmtMutator { public: - Stmt Mutate_(const For* op, const Stmt& s) final { - Stmt stmt = IRMutator::Mutate_(op, s); + Stmt VisitStmt_(const For* op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (op->for_type == ForType::Vectorized) { return For::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, @@ -557,7 +553,7 @@ class VectorizeSkipper : public IRMutator { }; Stmt SkipVectorize(Stmt stmt) { - return VectorizeSkipper().Mutate(stmt); + return VectorizeSkipper()(std::move(stmt)); } } // namespace ir From 4c2c34892893b589754e4ff8b3a2f115384517d0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:27:18 -0800 Subject: [PATCH 31/45] Migrate verify compact_buffer gpu_code --- src/pass/verify_compact_buffer.cc | 10 +++++----- src/pass/verify_gpu_code.cc | 24 ++++++++++++------------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/pass/verify_compact_buffer.cc b/src/pass/verify_compact_buffer.cc index c2131f6c7687..671b4a073117 100644 --- a/src/pass/verify_compact_buffer.cc +++ b/src/pass/verify_compact_buffer.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -32,15 +32,15 @@ namespace tvm { namespace ir { -class VerifyBuffer : public IRVisitor { +class VerifyBuffer : public StmtVisitor { public: bool Verify(const Stmt& stmt) { - this->Visit(stmt); + this->VisitStmt(stmt); return is_compact_; } - void Visit_(const AttrStmt* op) final { - IRVisitor::Visit_(op); + void VisitStmt_(const AttrStmt* op) final { + StmtVisitor::VisitStmt_(op); if (op->attr_key == attr::buffer_bind_scope) { is_compact_ = true; } diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 49a05345a99f..36871679a681 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,12 +26,12 @@ #include #include -#include +#include namespace tvm { namespace ir { -class GPUCodeVerifier : public IRVisitor { +class GPUCodeVerifier : public StmtVisitor { public: bool Verify(tvm::Stmt stmt, int64_t max_local_memory_per_block, @@ -49,12 +49,12 @@ class GPUCodeVerifier : public IRVisitor { Reset_(); - this->Visit(stmt); + this->VisitStmt(stmt); return valid_; } - void Visit_(const ProducerConsumer *op) { + void VisitStmt_(const ProducerConsumer *op) { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); @@ -62,10 +62,10 @@ class GPUCodeVerifier : public IRVisitor { if (op->is_producer) { nest_level_++; - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); nest_level_--; } else { - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } if (nest_level_ == 0) { @@ -77,8 +77,8 @@ class GPUCodeVerifier : public IRVisitor { } } - void Visit_(const Allocate *op) { - IRVisitor::Visit_(op); + void VisitStmt_(const Allocate *op) { + StmtVisitor::VisitStmt_(op); // visit an allocation of a buffer in shared memory, record its size if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { size_t size = static_cast(op->constant_allocation_size()); @@ -89,7 +89,7 @@ class GPUCodeVerifier : public IRVisitor { } } - void Visit_(const AttrStmt *op) { + void VisitStmt_(const AttrStmt *op) { if (op->attr_key == attr::storage_scope) { std::string op_value = op->value.as()->value; if (op_value == "local") { @@ -132,7 +132,7 @@ class GPUCodeVerifier : public IRVisitor { } } } - IRVisitor::Visit_(op); + StmtVisitor::VisitStmt_(op); } private: From ff0e27a1c76654b52ff5baa6e2867cdd26db737d Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:30:00 -0800 Subject: [PATCH 32/45] Migrate verify_memory --- src/pass/verify_memory.cc | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/pass/verify_memory.cc b/src/pass/verify_memory.cc index 4a5c8adeb8e7..6e5e02d3d79b 100644 --- a/src/pass/verify_memory.cc +++ b/src/pass/verify_memory.cc @@ -22,8 +22,9 @@ * \brief Pass to check if memory accesses are legal. */ #include -#include #include +#include + namespace tvm { namespace ir { @@ -39,7 +40,7 @@ namespace { * This pass performs such verification by checking if all Producer/Consumer * with memory accesses are bound with threads when device type is GPU. */ -class MemoryAccessVerifier final : protected IRVisitor { +class MemoryAccessVerifier final : protected StmtExprVisitor { public: /// Special member functions //@{ @@ -55,7 +56,7 @@ class MemoryAccessVerifier final : protected IRVisitor { /// Interface to perform memory access verification void Run() { if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return; - IRVisitor::Visit(func_->body); + StmtExprVisitor::VisitStmt(func_->body); } /// Verification result @@ -64,42 +65,47 @@ class MemoryAccessVerifier final : protected IRVisitor { protected: /// Visitor implementation //@{ - void Visit(const ObjectRef &n) final { + void VisitExpr(const Expr &n) final { + if (Failed()) return; + StmtExprVisitor::VisitExpr(n); + } + + void VisitStmt(const Stmt &n) final { if (Failed()) return; - IRVisitor::Visit(n); + StmtExprVisitor::VisitStmt(n); } - void Visit_(const LetStmt *op) final { + void VisitStmt_(const LetStmt *op) final { // Book keep definitions defs_[op->var.get()] = op->value; - return IRVisitor::Visit_(op); + return StmtExprVisitor::VisitStmt_(op); } - void Visit_(const AttrStmt *op) final { + void VisitStmt_(const AttrStmt *op) final { if (!InThreadEnv() && (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { EnterThreadEnv(); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); ExitThreadEnv(); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const ProducerConsumer *op) final { + void VisitStmt_(const ProducerConsumer *op) final { EnterProducerConsumer(op); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); ExitProducerConsumer(); } - void Visit_(const Load *op) final { + void VisitExpr_(const Load *op) final { HandleLoadStoreToVariable(op->buffer_var); - return IRVisitor::Visit_(op); + return StmtExprVisitor::VisitExpr_(op); } - void Visit_(const Store *op) final { + void VisitStmt_(const Store *op) final { HandleLoadStoreToVariable(op->buffer_var); - return IRVisitor::Visit_(op); + return StmtExprVisitor::VisitStmt_(op); } //@} From c38e155f4f9a8ea3b3fbd7ea045748604e7560bd Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:35:15 -0800 Subject: [PATCH 33/45] Migrate storage_sync --- src/pass/storage_sync.cc | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index dd6d2a1b7ee0..6ace4f7f85b4 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -22,8 +22,7 @@ */ #include #include -#include -#include +#include #include #include #include "ir_util.h" @@ -197,13 +196,13 @@ class ThreadSyncPlanner : public StorageAccessVisitor { StorageScope sync_scope_; }; -class ThreadSyncInserter : public IRMutator { +class ThreadSyncInserter : public StmtExprMutator { public: ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set& syncs) : sync_scope_(sync_scope), syncs_(syncs) {} - Stmt Mutate(Stmt stmt) final { + Stmt VisitStmt(const Stmt& stmt) final { if (syncs_.size() == 0) return stmt; if (syncs_.count(stmt.get())) { Stmt barrier; @@ -216,33 +215,33 @@ class ThreadSyncInserter : public IRMutator { Call::Intrinsic)); } // Mutate after query, to avoid stmt change. - stmt = IRMutator::Mutate(stmt); - stmt = Block::make(barrier, stmt); + auto ret = StmtExprMutator::VisitStmt(stmt); + ret = Block::make(barrier, ret); + return ret; } else { - stmt = IRMutator::Mutate(stmt); + return StmtExprMutator::VisitStmt(stmt); } - return stmt; } - Expr Mutate_(const Load* op, const Expr& e) final { + Expr VisitExpr_(const Load* op) final { if (sync_scope_.rank == StorageRank::kGlobal && GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; } - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Stmt Mutate_(const Store* op, const Stmt& s) final { + Stmt VisitStmt_(const Store* op) final { if (sync_scope_.rank == StorageRank::kGlobal && GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].write_count; } - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { bool temp = true; std::swap(temp, in_thread_env_); thread_extents_.push_back(op); - Stmt ret = IRMutator::Mutate_(op, s); + Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); std::swap(temp, in_thread_env_); // first thread scope. @@ -256,15 +255,15 @@ class ThreadSyncInserter : public IRMutator { const Variable* buf = op->node.as(); storage_scope_[buf] = StorageScope::make(op->value.as()->value); - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { - Expr expr = IRMutator::Mutate_(op, e); + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); const Variable* buffer_var = op->args[1].as(); @@ -280,7 +279,7 @@ class ThreadSyncInserter : public IRMutator { } return expr; } else { - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } } @@ -364,7 +363,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { StorageScope sync_scope = StorageScope::make(storage_scope); ThreadSyncPlanner planner(sync_scope); planner(stmt); - return ThreadSyncInserter(sync_scope, planner.syncs_inserted_).Mutate(stmt); + return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); } LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { From 9a1b932a3ffb384d93d62170debb50d66f35be24 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 11:35:33 -0800 Subject: [PATCH 34/45] Remove unused refs to mutator --- src/arithmetic/canonical_simplify.cc | 1 - src/arithmetic/const_fold.h | 1 - src/arithmetic/rewrite_simplify.cc | 1 - src/arithmetic/rewrite_simplify.h | 1 - src/arithmetic/stmt_simplify.cc | 1 - 5 files changed, 5 deletions(-) diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 9ad69a61ed5f..d05ee2dd9a30 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include "const_fold.h" #include "pattern_match.h" #include "rewrite_simplify.h" diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 93bf708a113f..8b4ea2fa8133 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -25,7 +25,6 @@ #define TVM_ARITHMETIC_CONST_FOLD_H_ #include -#include #include #include #include diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 60cc63ff2b96..f883bf145f59 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -24,7 +24,6 @@ // Acknowledgement: Most rewrite-rules are from Halide. #include #include -#include #include #include "const_fold.h" #include "pattern_match.h" diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index bce43658d68f..cf9dd6edbefa 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -26,7 +26,6 @@ #include #include -#include #include #include #include "const_fold.h" diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 1e58f9beaa65..4996cfd2628f 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include "ir_mutator_with_analyzer.h" From 0dbabebd1e6f6fb6efa3c53285183e6c6c4201b6 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 12:39:37 -0800 Subject: [PATCH 35/45] Migrate hybrid_op --- src/op/hybrid_op.cc | 48 ++++++++++++++++++++++----------------------- src/op/hybrid_op.h | 6 ++---- src/op/op_util.cc | 14 ++++++------- 3 files changed, 33 insertions(+), 35 deletions(-) diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index fd64b93a032a..f0bc200732fd 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -221,7 +221,7 @@ namespace op { Stmt ApplyLoopShapes(const Stage &stage, const std::unordered_map &dom_map, Stmt stmt) { - class LoopSpliter : public IRMutator { + class LoopSpliter : public StmtExprMutator { Expr factor; const Variable *parent; IterVar inner, outer; @@ -247,7 +247,7 @@ Stmt ApplyLoopShapes(const Stage &stage, outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type); } - Stmt Mutate_(const For *op, const Stmt &stmt) { + Stmt VisitStmt_(const For *op) final { if (op->loop_var.get() == parent) { std::unordered_map rmap; rmap[op->loop_var.get()] = inner + outer * factor; @@ -261,11 +261,11 @@ Stmt ApplyLoopShapes(const Stage &stage, splitted = true; return ret; } - return IRMutator::Mutate_(op, stmt); + return StmtExprMutator::VisitStmt_(op); } }; - class LoopFuser : public IRMutator { + class LoopFuser : public StmtExprMutator { const IterVar &parent; const Variable *inner; const Variable *outer; @@ -281,7 +281,7 @@ Stmt ApplyLoopShapes(const Stage &stage, // TODO(@were): Handle imperfect loops - Stmt Mutate_(const For *op, const Stmt &stmt) { + Stmt VisitStmt_(const For *op) { if (op->loop_var.get() == inner) { CHECK(under_outer); std::unordered_map rmap; @@ -291,7 +291,7 @@ Stmt ApplyLoopShapes(const Stage &stage, return ir::Substitute(op->body, rmap); } else if (op->loop_var.get() == outer) { under_outer = true; - Stmt body = IRMutator::Mutate(op->body); + Stmt body = this->VisitStmt(op->body); std::unordered_map rmap; rmap[op->loop_var.get()] = indexdiv(parent, extent); body = ir::Substitute(body, rmap); @@ -299,25 +299,25 @@ Stmt ApplyLoopShapes(const Stage &stage, return For::make(parent->var, Expr(0), extent * op->extent, op->for_type, op->device_api, body); } else if (under_outer) { - Stmt body = IRMutator::Mutate(op->body); + Stmt body = this->VisitStmt(op->body); std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); body = ir::Substitute(body, rmap); extent = extent * op->extent; return body; } - return IRMutator::Mutate(stmt); + return StmtExprMutator::VisitStmt_(op); } }; for (auto &rel : stage->relations) { if (const SplitNode *split = rel.as()) { LoopSpliter Spliter(split, dom_map); - stmt = Spliter.Mutate(stmt); + stmt = Spliter(stmt); CHECK(Spliter.splitted); } else if (const FuseNode *fuse = rel.as()) { LoopFuser Fuser(fuse); - stmt = Fuser.Mutate(stmt); + stmt = Fuser(stmt); CHECK(Fuser.fused); } } @@ -327,14 +327,14 @@ Stmt ApplyLoopShapes(const Stage &stage, Stmt ApplyLoopAnnotations(const Stage &stage, const std::unordered_map &rebased, Stmt stmt) { - class LoopAnnotator : public IRMutator { + class LoopAnnotator : public StmtMutator { const Variable *var; const IterVarAttr &attr; public: LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} - Stmt Mutate_(const For *op, const Stmt &stmt) { + Stmt VisitStmt_(const For *op) final { if (op->loop_var.get() == var) { if (attr->bind_thread.defined()) { const auto &iter_var = attr->bind_thread; @@ -352,7 +352,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, IterVarTypeToForType(attr->iter_type), op->device_api, op->body); } } - return IRMutator::Mutate_(op, stmt); + return StmtMutator::VisitStmt_(op); } }; @@ -381,7 +381,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, CHECK_EQ(found, 1) << " iter var should be found exactly once!"; if (need_change) { - stmt = LoopAnnotator(var, attr).Mutate(stmt); + stmt = LoopAnnotator(var, attr)(std::move(stmt)); } } return stmt; @@ -411,7 +411,7 @@ Stmt ApplyLoopOrder(const Stage &stage, } } - class LoopReorder : public IRMutator { + class LoopReorder : public StmtMutator { const Stage &stage; const std::unordered_map &dom_map; const std::unordered_map &reorder; @@ -422,13 +422,13 @@ Stmt ApplyLoopOrder(const Stage &stage, const std::unordered_map &reorder) : stage(stage), dom_map(dom_map), reorder(reorder) {} - Stmt Mutate_(const For *op, const Stmt &stmt) { + Stmt VisitStmt_(const For *op) final { // Reorder from in to out - Stmt body_ = IRMutator::Mutate(op->body); + Stmt body_ = this->VisitStmt(op->body); CHECK(reorder.count(op->loop_var.get())); auto target = reorder.find(op->loop_var.get())->second; if (body_.same_as(op->body) && op->loop_var.get() == target->var.get()) - return stmt; + return GetRef(op); const Stmt &body = op->body.same_as(body_) ? op->body : body_; ForType for_type = IterVarTypeToForType(target->iter_type); if (stage->iter_var_attrs.count(target)) { @@ -441,7 +441,7 @@ Stmt ApplyLoopOrder(const Stage &stage, }; if (need_reorder) - return LoopReorder(stage, dom_map, reorder).Mutate(stmt); + return LoopReorder(stage, dom_map, reorder)(stmt); return stmt; } @@ -479,21 +479,21 @@ std::vector GatherLoopVars(Stmt stmt) { } // replacer to replace tensors' usage in Provide -class ProviderReplacer : public ir::IRMutator { +class ProviderReplacer : public ir::StmtMutator { public: explicit ProviderReplacer(const std::unordered_map &vmap) : vmap_(vmap) {} - Stmt Mutate_(const ir::Provide* op, const Stmt &s) { + Stmt VisitStmt_(const ir::Provide* op) final { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); if (it != vmap_.end()) { Stmt ret = ir::Provide::make( it->second->op, it->second->value_index, op->value, op->args); found = true; - return IRMutator::Mutate_(ret.as(), ret); + return this->VisitStmt(ret); } - return IRMutator::Mutate_(op, s); + return StmtMutator::VisitStmt_(op); } // whether it is found. @@ -506,7 +506,7 @@ class ProviderReplacer : public ir::IRMutator { Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map &replace) { ProviderReplacer repl(replace); - Stmt ret = repl.Mutate(stmt); + Stmt ret = repl(stmt); return repl.found ? ret : stmt; } } // namespace op diff --git a/src/op/hybrid_op.h b/src/op/hybrid_op.h index 3e7b8a2ea764..f180129c263c 100644 --- a/src/op/hybrid_op.h +++ b/src/op/hybrid_op.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,6 @@ #define TVM_OP_HYBRID_OP_H_ #include -#include -#include #include #include #include diff --git a/src/op/op_util.cc b/src/op/op_util.cc index cd3b168d810b..801538ddf914 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -23,8 +23,8 @@ */ #include #include +#include #include -#include #include #include "op_util.h" #include "../schedule/message_passing.h" @@ -186,12 +186,12 @@ std::vector MakeIfNest(const std::vector& predicates) { } // replacer to replace tensors -class TensorReplacer : public ir::IRMutator { +class TensorReplacer : public ir::StmtExprMutator { public: explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - Expr Mutate_(const ir::Call* op, const Expr& e) { + Expr VisitExpr_(const ir::Call* op) { if (op->call_type == ir::Call::Halide) { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); @@ -200,10 +200,10 @@ class TensorReplacer : public ir::IRMutator { op->dtype, it->second->op->name, op->args, op->call_type, it->second->op, it->second->value_index); found = true; - return IRMutator::Mutate_(ret.as(), ret); + return this->VisitExpr(ret); } } - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } // whether it is found. @@ -216,13 +216,13 @@ class TensorReplacer : public ir::IRMutator { Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace) { TensorReplacer repl(replace); - Stmt ret = repl.Mutate(stmt); + Stmt ret = repl(stmt); return repl.found ? ret : stmt; } Expr ReplaceTensor(Expr expr, const std::unordered_map& replace) { TensorReplacer repl(replace); - Expr ret = repl.Mutate(expr); + Expr ret = repl(expr); return repl.found ? ret : expr; } From d3b980a0fcc853a20e2b1375280152c3aa8677de Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 12:47:13 -0800 Subject: [PATCH 36/45] Migrate tensorize --- src/op/tensorize.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 7ab54e983028..dba9ca092239 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -22,7 +22,7 @@ * \file tensorize.cc */ #include -#include +#include #include #include #include "op_util.h" @@ -157,10 +157,10 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, } // Remap the tensor placeholder, index and inline things. -class TensorIntrinMatcher final : public IRMutator { +class TensorIntrinMatcher final : public StmtExprMutator { public: - Expr Mutate_(const Call* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Call* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->call_type == Call::Halide) { Tensor t = Downcast(op->func).output(op->value_index); @@ -180,17 +180,17 @@ class TensorIntrinMatcher final : public IRMutator { return expr; } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = var_remap_.find(op); if (it != var_remap_.end()) { return it->second; } else { - return e; + return GetRef(op); } } - Expr Mutate_(const Reduce* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const Reduce* op) final { + Expr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); Array axis; for (size_t i = 0; i < op->axis.size(); ++i) { @@ -317,7 +317,7 @@ Array MatchTensorizeBody( matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space); Array ret; for (Expr expr : self->body) { - ret.push_back(matcher.Mutate(expr)); + ret.push_back(matcher(expr)); } return ret; } From 614503659fb07ff7e8ad736b6f6841b6fc88a1c5 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 12:52:44 -0800 Subject: [PATCH 37/45] Migrate schedule ops --- src/schedule/schedule_lang.cc | 1 - src/schedule/schedule_ops.cc | 91 +++++++++++++++++------------------ 2 files changed, 45 insertions(+), 47 deletions(-) diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index ec73c67bedff..91d3726f0bab 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include #include "graph.h" diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 0103410e6132..b177d6f8d22f 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -21,9 +21,8 @@ * \file schedule_ops.cc */ #include -#include #include -#include +#include #include #include #include @@ -71,7 +70,7 @@ Stmt MakePipeline(const Stage& s, } // inject the operator's realization on the stmt. -class InjectAttach : public IRMutator { +class InjectAttach : public StmtMutator { public: InjectAttach(const Stage& stage, const Stage& attach_spec, @@ -80,9 +79,9 @@ class InjectAttach : public IRMutator { : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} - Stmt Mutate(Stmt stmt) final { - CHECK(stmt.defined()); - stmt = IRMutator::Mutate(stmt); + Stmt VisitStmt(const Stmt& input_stmt) final { + CHECK(input_stmt.defined()); + auto stmt = StmtMutator::VisitStmt(input_stmt); const AttrStmt* op = stmt.as(); if (op != nullptr && op->attr_key == attr::loop_scope) { @@ -115,7 +114,7 @@ class InjectAttach : public IRMutator { }; // inject the operator's realization on the stmt. -class InjectScanStep : public IRMutator { +class InjectScanStep : public StmtMutator { public: InjectScanStep(const Stage& stage, const Operation& scan_op, @@ -125,9 +124,9 @@ class InjectScanStep : public IRMutator { : stage_(stage), scan_op_(scan_op), dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} - Stmt Mutate(Stmt stmt) final { - CHECK(stmt.defined()); - stmt = IRMutator::Mutate(stmt); + Stmt VisitStmt(const Stmt& input_stmt) final { + CHECK(input_stmt.defined()); + auto stmt = StmtMutator::VisitStmt(input_stmt); // update const AttrStmt* op = stmt.as(); if (op != nullptr && @@ -161,12 +160,12 @@ class InjectScanStep : public IRMutator { // Postprocessing of schedule op // Replace the init and update's expression by scan's buffer. -class SchedulePostProc : public IRMutator { +class SchedulePostProc : public StmtExprMutator { public: - Stmt Mutate_(const ProducerConsumer* op, const Stmt& s) final { + Stmt VisitStmt_(const ProducerConsumer* op) final { auto it = replace_op_.find(op->func.get()); if (it != replace_op_.end()) { - Stmt body = this->Mutate(op->body); + Stmt body = this->VisitStmt(op->body); if (it->second.defined()) { return ProducerConsumer::make( it->second, op->is_producer, body); @@ -174,36 +173,36 @@ class SchedulePostProc : public IRMutator { return body; } } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const LetStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const LetStmt* op) final { if (!HasSideEffect(op->value)) { - var_value_[op->var.get()] = Mutate(op->value); - return this->Mutate(op->body); + var_value_[op->var.get()] = this->VisitExpr(op->value); + return this->VisitStmt(op->body); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const AttrStmt* op, const Stmt& s) final { + Stmt VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::loop_scope || op->attr_key == attr::scan_init_scope) { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::scan_update_scope) { const ScanOpNode* scan = op->node.as(); CHECK(scan); var_value_[scan->scan_axis->var.get()] = op->value; - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else if (op->attr_key == attr::thread_extent) { // delete duplicated thread extent attr auto it = thread_extent_scope_.find(op->node.get()); if (it != thread_extent_scope_.end()) { CHECK(is_zero(ir::Simplify(it->second - op->value))); - return this->Mutate(op->body); + return this->VisitStmt(op->body); } else { thread_extent_scope_[op->node.get()] = op->value; - Stmt ret = IRMutator::Mutate_(op, s); + Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extent_scope_.erase(op->node.get()); return ret; } @@ -214,9 +213,9 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { Stmt ret = AttrStmt::make( it->second, op->attr_key, op->value, op->body); - return this->Mutate(ret); + return this->VisitStmt(ret); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } } else if (op->attr_key == ir::attr::buffer_bind_scope) { @@ -227,9 +226,9 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { return AttrStmt::make( Array{tuple[0], it->second.output(tensor->value_index)}, - op->attr_key, op->value, Mutate(op->body)); + op->attr_key, op->value, this->VisitStmt(op->body)); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } } else if (op->attr_key == ir::attr::buffer_dim_align) { @@ -239,16 +238,16 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { return AttrStmt::make( it->second.output(tensor->value_index), - op->attr_key, op->value, Mutate(op->body)); + op->attr_key, op->value, this->VisitStmt(op->body)); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } } - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } - Stmt Mutate_(const Realize* op, const Stmt& s) final { + Stmt VisitStmt_(const Realize* op) final { TensorKey key{op->func, op->value_index}; auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { @@ -256,29 +255,29 @@ class SchedulePostProc : public IRMutator { Stmt ret = Realize::make( it->second->op, it->second->value_index, op->dtype, op->bounds, op->condition, op->body); - return this->Mutate(ret); + return this->VisitStmt(ret); } else { - return this->Mutate(op->body); + return this->VisitStmt(op->body); } } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Stmt Mutate_(const Provide* op, const Stmt& s) final { + Stmt VisitStmt_(const Provide* op) final { TensorKey key{op->func, op->value_index}; auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; Stmt ret = Provide::make( dst->op, dst->value_index, op->value, op->args); - return this->Mutate(ret); + return this->VisitStmt(ret); } else { - return IRMutator::Mutate_(op, s); + return StmtExprMutator::VisitStmt_(op); } } - Expr Mutate_(const Call* op, const Expr& e) final { + Expr VisitExpr_(const Call* op) final { if (op->call_type == Call::Halide) { TensorKey key{op->func, op->value_index}; auto it = replace_buffer_.find(key); @@ -287,18 +286,18 @@ class SchedulePostProc : public IRMutator { Expr ret = Call::make( op->dtype, dst->op->name, op->args, op->call_type, dst->op, dst->value_index); - return this->Mutate(ret); + return this->VisitExpr(ret); } } - return IRMutator::Mutate_(op, e); + return StmtExprMutator::VisitExpr_(op); } - Expr Mutate_(const Variable* op, const Expr& e) final { + Expr VisitExpr_(const Variable* op) final { auto it = var_value_.find(op); if (it != var_value_.end()) { return it->second; } else { - return e; + return GetRef(op); } } @@ -392,14 +391,14 @@ Stmt ScheduleOps( if (scan_init.count(s->op)) { CHECK(body.defined()); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); - body = mu.Mutate(body); + body = mu(std::move(body)); CHECK(mu.found_attach) << "did not find attachment point for scan.init"; } else if (attach_spec->attach_type == kScanUpdate) { // Handle scan update CHECK(body.defined()); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop); - body = mu.Mutate(body); + body = mu(std::move(body)); CHECK(mu.found_attach) << "did not find attachment point for scan.update"; } else if (attach_spec->attach_type == kInlinedAlready) { @@ -411,7 +410,7 @@ Stmt ScheduleOps( CHECK_EQ(attach_spec->attach_type, kScope); CHECK(body.defined()); InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop); - body = mutator.Mutate(body); + body = mutator(std::move(body)); CHECK(mutator.found_attach) << "did not find attachment point for " << s << " in " << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar @@ -421,7 +420,7 @@ Stmt ScheduleOps( } SchedulePostProc post_proc; post_proc.Init(sch); - return post_proc.Mutate(body); + return post_proc(std::move(body)); } } // namespace schedule From eb39e63a58a62431be6bf447e5112e0db39c04f2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 13:02:03 -0800 Subject: [PATCH 38/45] Migrate schedule_dataflow_rewrite --- src/schedule/schedule_dataflow_rewrite.cc | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 70a73abc4698..fe5ea1e21875 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -22,7 +22,7 @@ */ #include #include -#include +#include #include #include #include "message_passing.h" @@ -42,24 +42,24 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { } // The replacer of cache. -class VarReplacer : public ir::IRMutator { +class VarReplacer : public ir::StmtExprMutator { public: explicit VarReplacer( const std::unordered_map& vsub) : vsub_(vsub) {} - Expr Mutate_(const Variable* op, const Expr& e) { + Expr VisitExpr_(const Variable* op) { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; - return e; + return GetRef(op); } ir::CommReducer MutateCommReducer(ir::CommReducer combiner) { // Replace free variables in combiner auto new_identity = ir::UpdateArray(combiner->identity_element, [this] (const Expr& e) { - return this->Mutate(e); + return this->VisitExpr(e); }); auto new_result = ir::UpdateArray(combiner->result, [this] (const Expr& e) { - return this->Mutate(e); + return this->VisitExpr(e); }); if (combiner->identity_element.same_as(new_identity) && @@ -71,8 +71,8 @@ class VarReplacer : public ir::IRMutator { } } - Expr Mutate_(const ir::Reduce* op, const Expr& e) { - Expr new_e = IRMutator::Mutate_(op, e); + Expr VisitExpr_(const ir::Reduce* op) { + Expr new_e = StmtExprMutator::VisitExpr_(op); const ir::Reduce* new_reduce = new_e.as(); ir::CommReducer new_combiner = MutateCommReducer(op->combiner); if (op->combiner.same_as(new_combiner)) { @@ -316,9 +316,9 @@ Array CacheWriteWithReLayout(Schedule sch, Array body_list; const ir::Reduce* first_reduce = nullptr; for (auto cbody : compute->body) { - body = VarReplacer(vsub).Mutate(cbody); + body = VarReplacer(vsub)(cbody); body = InjectPredicate(predicates, body); - body = VarReplacer(vsub2newvar).Mutate(body); + body = VarReplacer(vsub2newvar)(body); // Reduce nodes in ONE computeOp must be the same except value_index // This is right only if the original body ensures Reduce nodes are the same if (body->IsInstance()) { @@ -404,8 +404,8 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, for (Region old_region : tensor_op->input_regions) { Region region; for (Range r : old_region) { - Expr min = VarReplacer(vsub2newvar).Mutate(r->min); - Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); + Expr min = VarReplacer(vsub2newvar)(r->min); + Expr extent = VarReplacer(vsub2newvar)(r->extent); region.push_back(Range::make_by_min_extent(min, extent)); } new_regions.push_back(region); @@ -413,7 +413,7 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Array new_scalar_inputs; for (Expr old_input : tensor_op->scalar_inputs) { - new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input)); + new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } Operation cache_op = TensorComputeOpNode::make( @@ -786,9 +786,9 @@ Array Schedule::rfactor(const Tensor& tensor, } VarReplacer replacer(vsub); Array new_source = ir::UpdateArray(reduce->source, - [&replacer] (const Expr& e) { return replacer.Mutate(e); }); + [&replacer] (const Expr& e) { return replacer(e); }); - Expr new_pred = replacer.Mutate(predicate); + Expr new_pred = replacer(predicate); std::vector body; for (size_t idx = 0; idx < reduce->source.size(); ++idx) { From 3b5b63b1ccb2e9f320d970264a6a14cf718d9973 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 13:04:46 -0800 Subject: [PATCH 39/45] Migrate auto_inline_elemwise --- src/schedule/auto_inline_elem_wise.cc | 14 +++++++------- src/schedule/bound.cc | 1 - 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index e587f385734f..f5cd6c183737 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -22,23 +22,23 @@ */ #include #include -#include +#include namespace tvm { namespace schedule { using namespace ir; -class ElemWiseDetector : public ir::IRVisitor { +class ElemWiseDetector : public ir::ExprVisitor { public: explicit ElemWiseDetector(Array axis) : axis_(axis) {} - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (!is_elem_wise_) return; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { Array axis = op->args; if (axis_.size() != axis.size()) { is_elem_wise_ = false; @@ -51,7 +51,7 @@ class ElemWiseDetector : public ir::IRVisitor { return; } } - IRVisitor::Visit_(op); + ExprVisitor::VisitExpr_(op); } bool is_elem_wise_{true}; @@ -64,7 +64,7 @@ class ElemWiseDetector : public ir::IRVisitor { bool IsElemWise(const Operation& op) { if (const ComputeOpNode* compute = op.as()) { ElemWiseDetector v = ElemWiseDetector(compute->axis); - for (auto& e : compute->body) v.Visit(e); + for (auto& e : compute->body) v(e); return v.is_elem_wise_; } return false; diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index d4baded91f7c..7cf5cff0aff7 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -21,7 +21,6 @@ * \file bound.cc * \brief The bound inference logic. */ -#include #include #include #include From 34888bb32e564b15a49f799f4a64fc3d7d3b0573 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 13:05:53 -0800 Subject: [PATCH 40/45] Remove unecessary ref to visitor --- src/op/tensor_compute_op.cc | 1 - src/pass/ir_functor.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index cfd6e23a0db4..d82363e496ca 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include "./op_util.h" diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index fae4c03df585..7c06a853390f 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -21,7 +21,6 @@ */ #include #include -#include namespace tvm { namespace ir { From 0ab8bdcefbc62d9533091f147489f9b40657cf75 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 13:06:26 -0800 Subject: [PATCH 41/45] remove unecessary ref --- src/arithmetic/detect_linear_equation.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index c4ee40f12da8..b8ec974b436c 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include From bb2256b15f3ce7f96f3e8e1959ea16792d56f541 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 13:09:17 -0800 Subject: [PATCH 42/45] Migrate bound_deducer --- src/arithmetic/bound_deducer.cc | 46 ++++++++++++++++----------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 0b84be291f71..6f98017e6a69 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include @@ -38,17 +38,17 @@ using namespace ir; // a visitor to find the path to the target variable // from a expression. -class VariablePathFinder: public IRVisitor { +class VariablePathFinder: public ExprVisitor { public: explicit VariablePathFinder(Expr target) : target_(target) {} - void Visit(const ObjectRef& node) final { + void VisitExpr(const Expr& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); if (!found_) path_.push_back(node.get()); if (node.same_as(target_)) found_ = true; - IRVisitor::Visit(node); + ExprVisitor::VisitExpr(node); if (!found_) path_.pop_back(); } @@ -64,14 +64,14 @@ class VariablePathFinder: public IRVisitor { // return empty vector to represent failure std::vector GetPath(Expr target, Expr expr) { VariablePathFinder v(target); - v.Visit(expr); + v(expr); return v.path_; } enum CompareOp {kGreater, kLess, kEqual}; // a visitor to deduce the bound of a variable from a expression -class BoundDeducer: public IRVisitor { +class BoundDeducer: public ExprVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; @@ -82,39 +82,39 @@ class BoundDeducer: public IRVisitor { void Deduce(); - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (!success_) return; if (e.get() == path_[iter_++]) { - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } else { success_ = false; return; } } - void Visit_(const LT* op) final { + void VisitExpr_(const LT* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const LE* op) final { + void VisitExpr_(const LE* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const GT* op) final { + void VisitExpr_(const GT* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const GE* op) final { + void VisitExpr_(const GE* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void Visit_(const Add* op) final { + void VisitExpr_(const Add* op) final { bool left = op->a.get() == path_[iter_]; result_ -= left ? op->b : op->a; - Visit(left ? op->a : op->b); + this->VisitExpr(left ? op->a : op->b); } - void Visit_(const Sub* op) final { + void VisitExpr_(const Sub* op) final { bool left = op->a.get() == path_[iter_]; if (left) { result_ += op->b; @@ -123,10 +123,10 @@ class BoundDeducer: public IRVisitor { result_ = - result_; comp_op = ReverseOp(comp_op); } - Visit(left ? op->a : op->b); + this->VisitExpr(left ? op->a : op->b); } - void Visit_(const Mul* op) final { + void VisitExpr_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; Expr target_var = left ? op->a : op->b; @@ -171,7 +171,7 @@ class BoundDeducer: public IRVisitor { // ( x <= -3/-2 --> x <= 1) } } - Visit(left ? op->a : op->b); + this->VisitExpr(left ? op->a : op->b); } Expr result_; @@ -194,17 +194,17 @@ class BoundDeducer: public IRVisitor { Analyzer analyzer_; }; -class BoundDeduceInputChecker: public IRVisitor { +class BoundDeduceInputChecker: public ExprVisitor { public: bool Check(BoundDeducer* deducer) { deducer_ = deducer; - Visit(deducer_->expr_); + this->VisitExpr(deducer_->expr_); return target_count == 1; } - void Visit(const ObjectRef& e) final { + void VisitExpr(const Expr& e) final { if (e.same_as(deducer_->target_)) ++target_count; - IRVisitor::Visit(e); + ExprVisitor::VisitExpr(e); } private: @@ -305,7 +305,7 @@ void BoundDeducer::Deduce() { } expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); - Visit(expr_); + this->VisitExpr(expr_); } void BoundDeducer::Relax() { From 9209ccada77e28abe006424154934122b786750e Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 13:10:47 -0800 Subject: [PATCH 43/45] Migrate domain_touched --- src/arithmetic/domain_touched.cc | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc index 947f0050c6cb..bdd5daa6265c 100644 --- a/src/arithmetic/domain_touched.cc +++ b/src/arithmetic/domain_touched.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -23,7 +23,7 @@ */ #include #include -#include +#include #include #include @@ -36,13 +36,13 @@ namespace arith { using namespace ir; // Find Read region of the tensor in the stmt. -class FuncTouchedDomain final : public IRVisitor { +class FuncTouchedDomain final : public StmtExprVisitor { public: FuncTouchedDomain(const Tensor &tensor, bool consider_calls, bool consider_provides) : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {} Domain Find(const Stmt& stmt) { - this->Visit(stmt); + operator()(stmt); Domain ret; Range none; for (size_t i = 0; i < bounds_.size(); ++i) { @@ -51,49 +51,49 @@ class FuncTouchedDomain final : public IRVisitor { return ret; } - void Visit_(const For *op) final { + void VisitStmt_(const For *op) final { const Variable* var = op->loop_var.get(); dom_map_[var] = IntSet::range( Range::make_by_min_extent(op->min, op->extent)); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } - void Visit_(const LetStmt* op) final { + void VisitStmt_(const LetStmt* op) final { dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); dom_map_.erase(op->var.get()); } /* TODO: Thread extent unitest not generated.*/ - void Visit_(const AttrStmt* op) final { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); const Variable* var = thread_axis->var.get(); dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } - void Visit_(const Call* op) final { + void VisitExpr_(const Call* op) final { if (consider_calls_ && tensor_->op.same_as(op->func) && tensor_->value_index == op->value_index) { Touch(op->args); } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); } - void Visit_(const Provide* op) final { + void VisitStmt_(const Provide* op) final { if (consider_provides_ && tensor_->op.same_as(op->func) && tensor_->value_index == op->value_index) { Touch(op->args); } - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } private: From 1645d557f885e3350ae637744028bf8ba941be78 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 13:16:09 -0800 Subject: [PATCH 44/45] Migrate autotvm feature touch extractor --- src/autotvm/feature_visitor.cc | 22 +++++++++++----------- src/autotvm/feature_visitor.h | 19 +++++++++++-------- src/autotvm/touch_extractor.cc | 10 +++++----- src/autotvm/touch_extractor.h | 32 ++++++++++++++++---------------- 4 files changed, 43 insertions(+), 40 deletions(-) diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index a75164cb8b15..4d2330fa9220 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,7 +29,7 @@ namespace tvm { namespace autotvm { // for loop -void FeatureVisitor::Visit_(const For *op) { +void FeatureVisitor::VisitStmt_(const For* op) { const auto *extent = op->extent.as(); int64_t loop_extent = -1; if (extent != nullptr) @@ -51,13 +51,13 @@ void FeatureVisitor::Visit_(const For *op) { } if (EnterItervar_(op->loop_var, loop_extent, ann)) { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); ExitItervar_(); } } // parallel axis, virtual thread -void FeatureVisitor::Visit_(const AttrStmt *op) { +void FeatureVisitor::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { VarExpr var = op->node.as()->var; @@ -86,24 +86,24 @@ void FeatureVisitor::Visit_(const AttrStmt *op) { } if (EnterItervar_(var, extent->value, ann)) { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); ExitItervar_(); } } else { - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); } } // memory access -void FeatureVisitor::Visit_(const Load *op) { +void FeatureVisitor::VisitExpr_(const Load* op) { EnterMem_(op->buffer_var, op->index); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitExpr_(op); ExitMem_(); } -void FeatureVisitor::Visit_(const Store *op) { +void FeatureVisitor::VisitStmt_(const Store* op) { EnterMem_(op->buffer_var, op->index); - IRVisitor::Visit_(op); + StmtExprVisitor::VisitStmt_(op); ExitMem_(); } diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index a14c934756f1..32d4e092ce75 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,7 +27,7 @@ #define TVM_AUTOTVM_FEATURE_VISITOR_H_ #include -#include +#include #include namespace tvm { @@ -48,15 +48,18 @@ enum AnnotationType { * \brief A base class for feature extractor, used for processing * for loop and memory access in the IR */ -class FeatureVisitor : public IRVisitor { +class FeatureVisitor : public StmtExprVisitor { public: // for loop - void Visit_(const For *op); - void Visit_(const AttrStmt *op); + void VisitStmt_(const For *op); + void VisitStmt_(const AttrStmt *op); // memory access - void Visit_(const Load *op); - void Visit_(const Store *op); + void VisitExpr_(const Load *op); + void VisitStmt_(const Store *op); + + using StmtExprVisitor::VisitStmt_; + using StmtExprVisitor::VisitExpr_; protected: /*! diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index f66a724595c6..fcb1d611c3b0 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -44,14 +44,14 @@ int ParallelLevel(AnnotationType ann) { } // get touch pattern from index expression -class IndexParser: public IRVisitor { +class IndexParser: public ExprVisitor { public: void Parse(Expr expr) { pattern_map.clear(); - this->Visit(expr); + this->VisitExpr(expr); } - void Visit_(const Variable *op) { + void VisitExpr_(const Variable *op) { // TODO(lmzheng): handle more index types (multiple occurrence) if (pattern_map.count(op) == 0) { pattern_map[op] = TouchPattern(); @@ -60,13 +60,13 @@ class IndexParser: public IRVisitor { } } - void Visit_(const Mul *op) { + void VisitExpr_(const Mul *op) { if (op->a.as()) { if (const auto stride = op->b.as()) { next_stride_ = stride->value; } } - IRVisitor::Visit_(op); + ExprVisitor::VisitExpr_(op); } std::unordered_map pattern_map; diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 1028b0144e12..027788cfaf03 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,7 +26,7 @@ #define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #include -#include +#include #include #include #include @@ -85,39 +85,39 @@ struct ItervarFeature { // extract iter vars and their touch pattern from ir class TouchExtractor : public FeatureVisitor { public: - void Analyze(Stmt stmt) { - this->Visit(stmt); + void Analyze(const Stmt& stmt) { + operator()(stmt); } // arithmetic stats - void Visit_(const Add *op) { + void VisitExpr_(const Add *op) { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; - IRVisitor::Visit_(op); + FeatureVisitor::VisitExpr_(op); } - void Visit_(const Sub *op) { + void VisitExpr_(const Sub *op) { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; - IRVisitor::Visit_(op); + FeatureVisitor::VisitExpr_(op); } - void Visit_(const Mul *op) { + void VisitExpr_(const Mul *op) { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; - IRVisitor::Visit_(op); + FeatureVisitor::VisitExpr_(op); } - void Visit_(const Div *op) { + void VisitExpr_(const Div *op) { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; - IRVisitor::Visit_(op); + FeatureVisitor::VisitExpr_(op); } - void Visit_(const Mod *op) { + void VisitExpr_(const Mod *op) { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; - IRVisitor::Visit_(op); + FeatureVisitor::VisitExpr_(op); } std::unordered_map itervar_map; @@ -134,7 +134,7 @@ class TouchExtractor : public FeatureVisitor { std::deque itervar_stack_; // use deque instead of stack for indexing std::deque skip_stack_size_; - using IRVisitor::Visit_; + using FeatureVisitor::VisitExpr_; }; } // namespace autotvm From b48e228072f1da1fad4178c7cd7f10db6ef1f6d2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 2 Jan 2020 14:22:00 -0800 Subject: [PATCH 45/45] Add annotations --- src/op/hybrid_op.cc | 5 ++--- src/op/op_util.cc | 2 +- src/pass/inject_virtual_thread.cc | 10 +++++----- src/pass/verify_gpu_code.cc | 6 +++--- src/pass/verify_memory.cc | 10 +++++----- src/schedule/schedule_dataflow_rewrite.cc | 4 ++-- 6 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index f0bc200732fd..4de5f1cff18d 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -280,8 +280,7 @@ Stmt ApplyLoopShapes(const Stage &stage, extent(0), fused(false) {} // TODO(@were): Handle imperfect loops - - Stmt VisitStmt_(const For *op) { + Stmt VisitStmt_(const For* op) final { if (op->loop_var.get() == inner) { CHECK(under_outer); std::unordered_map rmap; @@ -422,7 +421,7 @@ Stmt ApplyLoopOrder(const Stage &stage, const std::unordered_map &reorder) : stage(stage), dom_map(dom_map), reorder(reorder) {} - Stmt VisitStmt_(const For *op) final { + Stmt VisitStmt_(const For* op) final { // Reorder from in to out Stmt body_ = this->VisitStmt(op->body); CHECK(reorder.count(op->loop_var.get())); diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 801538ddf914..4a6d0d2f302a 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -191,7 +191,7 @@ class TensorReplacer : public ir::StmtExprMutator { explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - Expr VisitExpr_(const ir::Call* op) { + Expr VisitExpr_(const ir::Call* op) final { if (op->call_type == ir::Call::Halide) { Tensor t = Downcast(op->func).output(op->value_index); auto it = vmap_.find(t); diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index ac94d57f1829..202a5c27bd8b 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -96,19 +96,19 @@ class ExprTouched final : public StmtExprVisitor { // Analyze if the buffers are invariant to value of var class VarTouchedAnalysis : public StmtVisitor { public: - void VisitStmt_(const LetStmt* op) { + void VisitStmt_(const LetStmt* op) final { ExprTouched tc(touched_var_, false); tc(op->value); Record(op->var.get(), tc); this->VisitStmt(op->body); } - void VisitStmt_(const Store *op) { + void VisitStmt_(const Store* op) final { ExprTouched tc(touched_var_, false); tc(op->value); tc(op->index); Record(op->buffer_var.get(), tc); } - void VisitStmt_(const For *op) { + void VisitStmt_(const For* op) final { ExprTouched tc(touched_var_, false); tc(op->min); tc(op->extent); @@ -116,14 +116,14 @@ class VarTouchedAnalysis : public StmtVisitor { this->VisitStmt(op->body); } // external function call - void VisitStmt_(const Evaluate *op) { + void VisitStmt_(const Evaluate* op) final { ExprTouched tc(touched_var_, true); tc(op->value); for (const Variable* var : tc.write_vars_) { Record(var, tc); } } - void VisitStmt_(const Allocate *op) { + void VisitStmt_(const Allocate* op) final { ExprTouched tc(touched_var_, false); for (size_t i = 0; i < op->extents.size(); ++i) { tc(op->extents[i]); diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 36871679a681..1adc6851caf4 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -54,7 +54,7 @@ class GPUCodeVerifier : public StmtVisitor { return valid_; } - void VisitStmt_(const ProducerConsumer *op) { + void VisitStmt_(const ProducerConsumer* op) final { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); @@ -77,7 +77,7 @@ class GPUCodeVerifier : public StmtVisitor { } } - void VisitStmt_(const Allocate *op) { + void VisitStmt_(const Allocate* op) final { StmtVisitor::VisitStmt_(op); // visit an allocation of a buffer in shared memory, record its size if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { @@ -89,7 +89,7 @@ class GPUCodeVerifier : public StmtVisitor { } } - void VisitStmt_(const AttrStmt *op) { + void VisitStmt_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { std::string op_value = op->value.as()->value; if (op_value == "local") { diff --git a/src/pass/verify_memory.cc b/src/pass/verify_memory.cc index 6e5e02d3d79b..415841d04017 100644 --- a/src/pass/verify_memory.cc +++ b/src/pass/verify_memory.cc @@ -75,13 +75,13 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { StmtExprVisitor::VisitStmt(n); } - void VisitStmt_(const LetStmt *op) final { + void VisitStmt_(const LetStmt* op) final { // Book keep definitions defs_[op->var.get()] = op->value; return StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const AttrStmt *op) final { + void VisitStmt_(const AttrStmt* op) final { if (!InThreadEnv() && (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { EnterThreadEnv(); @@ -92,18 +92,18 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } } - void VisitStmt_(const ProducerConsumer *op) final { + void VisitStmt_(const ProducerConsumer* op) final { EnterProducerConsumer(op); StmtExprVisitor::VisitStmt_(op); ExitProducerConsumer(); } - void VisitExpr_(const Load *op) final { + void VisitExpr_(const Load* op) final { HandleLoadStoreToVariable(op->buffer_var); return StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const Store *op) final { + void VisitStmt_(const Store* op) final { HandleLoadStoreToVariable(op->buffer_var); return StmtExprVisitor::VisitStmt_(op); } diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index fe5ea1e21875..9aef563fbefc 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -47,7 +47,7 @@ class VarReplacer : public ir::StmtExprMutator { explicit VarReplacer( const std::unordered_map& vsub) : vsub_(vsub) {} - Expr VisitExpr_(const Variable* op) { + Expr VisitExpr_(const Variable* op) final { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; return GetRef(op); @@ -71,7 +71,7 @@ class VarReplacer : public ir::StmtExprMutator { } } - Expr VisitExpr_(const ir::Reduce* op) { + Expr VisitExpr_(const ir::Reduce* op) final { Expr new_e = StmtExprMutator::VisitExpr_(op); const ir::Reduce* new_reduce = new_e.as(); ir::CommReducer new_combiner = MutateCommReducer(op->combiner);