Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class TVM_DLL Object {
template <typename TargetType>
inline bool IsInstance() const;
/*!
* \return Weather the cell has only one reference
* \return Whether the cell has only one reference
* \note We use stl style naming to be consistent with known API in shared_ptr.
*/
inline bool unique() const;
Expand Down Expand Up @@ -337,7 +337,7 @@ inline RelayRefType GetRef(const ObjectType* ptr);
/*!
* \brief Downcast a base reference type to a more specific type.
*
* \param ref The inptut reference
* \param ref The input reference
* \return The corresponding SubRef.
* \tparam SubRef The target specific reference type.
* \tparam BaseRef the current reference type.
Expand Down Expand Up @@ -416,7 +416,7 @@ class ObjectPtr {
return *get();
}
/*!
* \brief copy assignmemt
* \brief copy assignment
* \param other The value to be assigned.
* \return reference to self.
*/
Expand All @@ -427,7 +427,7 @@ class ObjectPtr {
return *this;
}
/*!
* \brief move assignmemt
* \brief move assignment
* \param other The value to be assigned.
* \return reference to self.
*/
Expand Down Expand Up @@ -632,7 +632,7 @@ struct ObjectPtrEqual {
};

/*!
* \brief helper macro to declare a base object type that can be inheritated.
* \brief helper macro to declare a base object type that can be inherited.
* \param TypeName The name of the current type.
* \param ParentType The name of the ParentType
*/
Expand All @@ -648,10 +648,10 @@ struct ObjectPtrEqual {
return _GetOrAllocRuntimeTypeIndex(); \
} \
static uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \
static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \
return tidx; \
return tindex; \
}

/*!
Expand All @@ -664,7 +664,7 @@ struct ObjectPtrEqual {
static const constexpr int _type_child_slots = 0; \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)

/*! \brief helper macro to supress unused warning */
/*! \brief helper macro to suppress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
Expand All @@ -686,7 +686,7 @@ struct ObjectPtrEqual {
TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex()

/*
* \brief Define the default copy/move constructor and assign opeator
* \brief Define the default copy/move constructor and assign operator
* \param TypeName The class typename.
*/
#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
Expand Down Expand Up @@ -827,7 +827,7 @@ inline bool Object::IsInstance() const {
if (!TargetType::_type_child_slots_can_overflow) return false;
// Invariance: parent index is always smaller than the child.
if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false;
// The rare slower-path, check type hierachy.
// The rare slower-path, check type hierarchy.
return self->DerivedFrom(TargetType::RuntimeTypeIndex());
}
} else {
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ struct ExprDeepEqual {
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};

/*!
* \brief Visit the PrimFuncs in the IRModule
* \tparam FLambda The type of the PrimFunc visitor
* \param mod The IRModule to be visited
* \param fvisit The visitor to the PrimFuncs in the IRModule
*/
template <class FLambda>
inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
for (const auto& kv : mod->functions) {
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
fvisit(prim_func);
}
}
}

/*!
* \brief Find undefined vars in the statement.
* \param stmt The function to be checked.
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,15 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
return Substitute(std::move(input), vmap);
}

/*!
* \brief Recursively visit the IR in pre DFS order node, apply fvisit.
* If fvisit returns false, it won't visit the children of the node.
* \param stmt_or_expr The ir to be visited.
* \param fvisit The visitor function to be applied. If fvisit returns false, it won't visit the
* children of the node
*/
TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
const std::function<bool(const ObjectRef&)>& fvisit);
} // namespace tir
} // namespace tvm

Expand Down
11 changes: 5 additions & 6 deletions include/tvm/topi/detail/constant_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,11 @@ inline std::vector<int64_t> GetConstInt64Values(Array<PrimExpr> exprs,
}

/*!
* \brief Check weather the two expressions are equal or not, if not simplify the expressions and
* check again \note This is stronger equality check than tvm::tir::Equal
*
* \param lhs First expreesion
* \param rhs Second expreesion
*
* \brief Check whether the two expressions are equal or not, if not simplify the expressions and
* check again
* \note This is stronger equality check than tvm::tir::Equal
* \param lhs First expression
* \param rhs Second expression
* \return result True if both expressions are equal, else false
*/
inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
Expand Down
54 changes: 49 additions & 5 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
/*!
* \file stmt_functor.cc
*/
#include <tvm/ir/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

#include <functional>

#include "functor_common.h"
#include "./functor_common.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -631,9 +633,9 @@ Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder,
return transform(std::move(ir_node));
}

class IRSubstitue : public StmtExprMutator {
class IRSubstitute : public StmtExprMutator {
public:
explicit IRSubstitue(std::function<Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {}
explicit IRSubstitute(std::function<Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {}

PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
Expand Down Expand Up @@ -679,11 +681,53 @@ class IRSubstitue : public StmtExprMutator {
};

Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstitue(vmap)(std::move(stmt));
return IRSubstitute(vmap)(std::move(stmt));
}

PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap) {
return IRSubstitue(vmap)(std::move(expr));
return IRSubstitute(vmap)(std::move(expr));
}

void PreOrderVisit(const ObjectRef& stmt_or_expr,
const std::function<bool(const ObjectRef&)>& fvisit) {
class PreOrderVisitor : public StmtExprVisitor {
public:
explicit PreOrderVisitor(const std::function<bool(const ObjectRef&)>& f) : f_(f) {}

private:
void VisitExpr(const PrimExpr& expr) final {
const PrimExprNode* p_expr = expr.get();
if (visited_.count(p_expr) == 0) {
visited_.insert(p_expr);
if (f_(expr)) {
ExprVisitor::VisitExpr(expr);
}
}
}

void VisitStmt(const Stmt& stmt) final {
const StmtNode* p_stmt = stmt.get();
if (visited_.count(p_stmt) == 0) {
visited_.insert(p_stmt);
if (f_(stmt)) {
StmtVisitor::VisitStmt(stmt);
}
}
}

const std::function<bool(const ObjectRef&)>& f_;
std::unordered_set<const Object*> visited_;
};

PreOrderVisitor visitor(fvisit);
if (const auto* stmt = stmt_or_expr.as<StmtNode>()) {
visitor(GetRef<Stmt>(stmt));
} else if (const auto* expr = stmt_or_expr.as<PrimExprNode>()) {
visitor(GetRef<PrimExpr>(expr));
} else {
LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: "
<< stmt_or_expr->GetTypeKey();
}
}

TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform);
Expand Down
53 changes: 53 additions & 0 deletions tests/cpp/ir_functor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@

#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/relay/function.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

Expand Down Expand Up @@ -52,6 +56,55 @@ TEST(IRF, CountVar) {
ICHECK_EQ(n_var, 2);
}

TEST(IRF, VisitPrimFuncs) {
using namespace tvm;
using namespace tvm::tir;
PrimFunc prim_func(/*params=*/{}, /*body=*/Evaluate(Integer(0)));
relay::Function relay_func(/*params=*/{}, /*body=*/relay::Expr(nullptr),
/*ret_type=*/relay::Type{nullptr}, /*ty_params=*/{});
IRModule mod({
{GlobalVar("main"), prim_func},
{GlobalVar("main2"), relay_func},
});
int n_visited = 0;
VisitPrimFuncs(mod, [&](const PrimFuncNode* func) { ++n_visited; });
ASSERT_EQ(n_visited, 1);
}

TEST(IRF, PreOrderVisit) {
using namespace tvm;
using namespace tvm::tir;
Stmt init = IfThenElse(const_true(), Evaluate(Integer(0)), Evaluate(Integer(0)));
Stmt body = Evaluate(Integer(1));
Block block(/*iter_vars=*/{}, /*reads=*/{},
/*writes=*/{}, /*name_hint=*/"block", /*body=*/body,
/*init=*/init);
bool init_visited = false;
bool stopped_at_if = true;
bool body_visited = false;
PreOrderVisit(block, [&](const ObjectRef& n) -> bool {
if (n->IsInstance<IfThenElseNode>()) {
init_visited = true;
return false;
}
if (const auto* eval = n.as<EvaluateNode>()) {
if (const auto* int_imm = eval->value.as<IntImmNode>()) {
if (int_imm->value == 0) {
stopped_at_if = false;
} else if (int_imm->value == 1) {
body_visited = true;
} else {
LOG(FATAL) << "Unreachable";
}
}
}
return true;
});
ASSERT_EQ(init_visited, true);
ASSERT_EQ(stopped_at_if, true);
ASSERT_EQ(body_visited, true);
}

TEST(IRF, ExprTransform) {
using namespace tvm;
using namespace tvm::tir;
Expand Down