diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 70ab7688c450..47788394126e 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -186,7 +186,7 @@ class TVM_DLL Object { template inline bool IsInstance() const; /*! - * \return Weather the cell has only one reference + * \return Whether the cell has only one reference * \note We use stl style naming to be consistent with known API in shared_ptr. */ inline bool unique() const; @@ -337,7 +337,7 @@ inline RelayRefType GetRef(const ObjectType* ptr); /*! * \brief Downcast a base reference type to a more specific type. * - * \param ref The inptut reference + * \param ref The input reference * \return The corresponding SubRef. * \tparam SubRef The target specific reference type. * \tparam BaseRef the current reference type. @@ -416,7 +416,7 @@ class ObjectPtr { return *get(); } /*! - * \brief copy assignmemt + * \brief copy assignment * \param other The value to be assigned. * \return reference to self. */ @@ -427,7 +427,7 @@ class ObjectPtr { return *this; } /*! - * \brief move assignmemt + * \brief move assignment * \param other The value to be assigned. * \return reference to self. */ @@ -632,7 +632,7 @@ struct ObjectPtrEqual { }; /*! - * \brief helper macro to declare a base object type that can be inheritated. + * \brief helper macro to declare a base object type that can be inherited. * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ @@ -648,10 +648,10 @@ struct ObjectPtrEqual { return _GetOrAllocRuntimeTypeIndex(); \ } \ static uint32_t _GetOrAllocRuntimeTypeIndex() { \ - static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \ + static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \ TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \ TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \ - return tidx; \ + return tindex; \ } /*! @@ -664,7 +664,7 @@ struct ObjectPtrEqual { static const constexpr int _type_child_slots = 0; \ TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) -/*! \brief helper macro to supress unused warning */ +/*! \brief helper macro to suppress unused warning */ #if defined(__GNUC__) #define TVM_ATTRIBUTE_UNUSED __attribute__((unused)) #else @@ -686,7 +686,7 @@ struct ObjectPtrEqual { TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex() /* - * \brief Define the default copy/move constructor and assign opeator + * \brief Define the default copy/move constructor and assign operator * \param TypeName The class typename. */ #define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ @@ -827,7 +827,7 @@ inline bool Object::IsInstance() const { if (!TargetType::_type_child_slots_can_overflow) return false; // Invariance: parent index is always smaller than the child. if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; - // The rare slower-path, check type hierachy. + // The rare slower-path, check type hierarchy. return self->DerivedFrom(TargetType::RuntimeTypeIndex()); } } else { diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index e5b2c2b6957c..1ad78596586a 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -56,6 +56,22 @@ struct ExprDeepEqual { TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; }; +/*! + * \brief Visit the PrimFuncs in the IRModule + * \tparam FLambda The type of the PrimFunc visitor + * \param mod The IRModule to be visited + * \param fvisit The visitor to the PrimFuncs in the IRModule + */ +template +inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) { + for (const auto& kv : mod->functions) { + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + fvisit(prim_func); + } + } +} + /*! * \brief Find undefined vars in the statement. * \param stmt The function to be checked. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index d6303ae266e1..c1c618f0c22f 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -386,6 +386,15 @@ inline T Substitute(T input, const std::unordered_map& return Substitute(std::move(input), vmap); } +/*! + * \brief Recursively visit the IR in pre DFS order node, apply fvisit. + * If fvisit returns false, it won't visit the children of the node. + * \param stmt_or_expr The ir to be visited. + * \param fvisit The visitor function to be applied. If fvisit returns false, it won't visit the + * children of the node + */ +TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr, + const std::function& fvisit); } // namespace tir } // namespace tvm diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 92ff3a4e3804..95e68f5f6d61 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -119,12 +119,11 @@ inline std::vector GetConstInt64Values(Array exprs, } /*! - * \brief Check weather the two expressions are equal or not, if not simplify the expressions and - * check again \note This is stronger equality check than tvm::tir::Equal - * - * \param lhs First expreesion - * \param rhs Second expreesion - * + * \brief Check whether the two expressions are equal or not, if not simplify the expressions and + * check again + * \note This is stronger equality check than tvm::tir::Equal + * \param lhs First expression + * \param rhs Second expression * \return result True if both expressions are equal, else false */ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 639d38db0a81..07574e4fb2f1 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -19,12 +19,14 @@ /*! * \file stmt_functor.cc */ +#include #include +#include #include #include -#include "functor_common.h" +#include "./functor_common.h" namespace tvm { namespace tir { @@ -631,9 +633,9 @@ Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, return transform(std::move(ir_node)); } -class IRSubstitue : public StmtExprMutator { +class IRSubstitute : public StmtExprMutator { public: - explicit IRSubstitue(std::function(const Var&)> vmap) : vmap_(vmap) {} + explicit IRSubstitute(std::function(const Var&)> vmap) : vmap_(vmap) {} PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); @@ -679,11 +681,53 @@ class IRSubstitue : public StmtExprMutator { }; Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { - return IRSubstitue(vmap)(std::move(stmt)); + return IRSubstitute(vmap)(std::move(stmt)); } PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { - return IRSubstitue(vmap)(std::move(expr)); + return IRSubstitute(vmap)(std::move(expr)); +} + +void PreOrderVisit(const ObjectRef& stmt_or_expr, + const std::function& fvisit) { + class PreOrderVisitor : public StmtExprVisitor { + public: + explicit PreOrderVisitor(const std::function& f) : f_(f) {} + + private: + void VisitExpr(const PrimExpr& expr) final { + const PrimExprNode* p_expr = expr.get(); + if (visited_.count(p_expr) == 0) { + visited_.insert(p_expr); + if (f_(expr)) { + ExprVisitor::VisitExpr(expr); + } + } + } + + void VisitStmt(const Stmt& stmt) final { + const StmtNode* p_stmt = stmt.get(); + if (visited_.count(p_stmt) == 0) { + visited_.insert(p_stmt); + if (f_(stmt)) { + StmtVisitor::VisitStmt(stmt); + } + } + } + + const std::function& f_; + std::unordered_set visited_; + }; + + PreOrderVisitor visitor(fvisit); + if (const auto* stmt = stmt_or_expr.as()) { + visitor(GetRef(stmt)); + } else if (const auto* expr = stmt_or_expr.as()) { + visitor(GetRef(expr)); + } else { + LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: " + << stmt_or_expr->GetTypeKey(); + } } TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 237dc46b99ca..1f7d18f747ea 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -19,10 +19,14 @@ #include #include +#include #include +#include +#include #include #include #include +#include #include #include @@ -52,6 +56,55 @@ TEST(IRF, CountVar) { ICHECK_EQ(n_var, 2); } +TEST(IRF, VisitPrimFuncs) { + using namespace tvm; + using namespace tvm::tir; + PrimFunc prim_func(/*params=*/{}, /*body=*/Evaluate(Integer(0))); + relay::Function relay_func(/*params=*/{}, /*body=*/relay::Expr(nullptr), + /*ret_type=*/relay::Type{nullptr}, /*ty_params=*/{}); + IRModule mod({ + {GlobalVar("main"), prim_func}, + {GlobalVar("main2"), relay_func}, + }); + int n_visited = 0; + VisitPrimFuncs(mod, [&](const PrimFuncNode* func) { ++n_visited; }); + ASSERT_EQ(n_visited, 1); +} + +TEST(IRF, PreOrderVisit) { + using namespace tvm; + using namespace tvm::tir; + Stmt init = IfThenElse(const_true(), Evaluate(Integer(0)), Evaluate(Integer(0))); + Stmt body = Evaluate(Integer(1)); + Block block(/*iter_vars=*/{}, /*reads=*/{}, + /*writes=*/{}, /*name_hint=*/"block", /*body=*/body, + /*init=*/init); + bool init_visited = false; + bool stopped_at_if = true; + bool body_visited = false; + PreOrderVisit(block, [&](const ObjectRef& n) -> bool { + if (n->IsInstance()) { + init_visited = true; + return false; + } + if (const auto* eval = n.as()) { + if (const auto* int_imm = eval->value.as()) { + if (int_imm->value == 0) { + stopped_at_if = false; + } else if (int_imm->value == 1) { + body_visited = true; + } else { + LOG(FATAL) << "Unreachable"; + } + } + } + return true; + }); + ASSERT_EQ(init_visited, true); + ASSERT_EQ(stopped_at_if, true); + ASSERT_EQ(body_visited, true); +} + TEST(IRF, ExprTransform) { using namespace tvm; using namespace tvm::tir;