Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>

#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -408,6 +409,73 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
*/
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);

/*!
* \brief A function to iteratively traverse dataflow regions of a graph
*
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
* order of nodes in an input graph.
*
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
* leaf via fvisit_leaf.
*
* This function should be used internally to other classes to implement mixed-mode traversals. The
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
* hits a non-dataflow node.
*
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
*/
template <typename FCheckVisited, typename FVisitLeaf>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to expose this in the header?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a template function and we need to use it in type_infer.cc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InferType directly extends ExprFunctor instead of using ExprVisitor or ExprMutator

void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
if (!fcheck_visited(expr)) {
stack.push({expr, false});
}
};
fpush_to_stack(expr);
while (stack.size() > 0) {
auto node = stack.top().first;
if (fcheck_visited(node)) {
// if this node was visited through another path
// after being added to the stack ignore it.
stack.pop();
} else if (stack.top().second) {
// all the children have already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
// No need to expand the children directly run visit.
fvisit_leaf(node);
stack.pop();
}
}
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
68 changes: 0 additions & 68 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,74 +33,6 @@

namespace tvm {
namespace relay {
/*!
* \brief A function to iteratively traverse dataflow regions of a graph
*
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
* order of nodes in an input graph.
*
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
* leaf via fvisit_leaf.
*
* This function should be used internally to other classes to implement mixed-mode traversals. The
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
* hits a non-dataflow node.
*
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
*/
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
if (!fcheck_visited(expr)) {
stack.push({expr, false});
}
};
fpush_to_stack(expr);
while (stack.size() > 0) {
auto node = stack.top().first;
if (fcheck_visited(node)) {
// if this node was visited through another path
// after being added to the stack ignore it.
stack.pop();
} else if (stack.top().second) {
// all the children have already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
// No need to expand the children directly run visit.
fvisit_leaf(node);
stack.pop();
}
}
}

MixedModeVisitor::MixedModeVisitor(int visit_limit) {
ICHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
ICHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/algorithm/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TopKAttrs* param = attrs.as<TopKAttrs>();
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
ICHECK(data);
if (data == nullptr) return false;
int ndim = data->shape.size();
int axis = param->axis;
if (axis < 0) {
Expand Down
51 changes: 43 additions & 8 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,37 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
TypeRelationFn tuple_getitem_rel_;
TypeRelationFn make_tuple_rel_;

/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> memo_;

void VisitLeaf(const Expr& expr) {
if (!memo_.count(expr)) {
Type ret = this->DispatchVisitExpr(expr);
memo_[expr] = ret;
}
}

bool CheckVisited(const Expr& expr) {
if (memo_.count(expr)) {
return true;
} else {
return false;
}
}

Type DispatchVisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); }

Type VisitExpr(const Expr& expr) final {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
if (memo_.count(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
return memo_[expr];
}
}

// Perform unification on two types and report the error at the expression
// or the span of the expression.
Type Unify(const Type& t1, const Type& t2, const Span& span) {
Expand Down Expand Up @@ -546,12 +577,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
}
};

class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
public:
Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap,
TypeSolver* solver)
: tmap_(tmap), solver_(solver) {}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }

Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); }
Expand All @@ -560,13 +593,15 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {

Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); }

Expr VisitExpr_(const TupleNode* op) final { return AttachCheckedType(op); }
Expr Rewrite_(const TupleNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

Expr VisitExpr_(const TupleGetItemNode* op) final { return AttachCheckedType(op); }
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
return AttachCheckedType(op, post);
}

Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); }

Expr VisitExpr_(const CallNode* op) final { return AttachCheckedType(op); }
Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); }

Expand All @@ -593,7 +628,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {

// attach checked type to the mutated node.
template <typename T>
Expr AttachCheckedType(const T* op) {
Expr AttachCheckedType(const T* op, const Expr& post = Expr()) {
auto it = tmap_.find(GetRef<Expr>(op));
ICHECK(it != tmap_.end());
Type checked_type = solver_->Resolve(it->second.checked_type);
Expand All @@ -606,7 +641,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
<< " check other reported errors for hints of what may of happened.");
}

Expr new_e = ExprMutator::VisitExpr_(op);
Expr new_e = post.defined() ? post : ExprMutator::VisitExpr_(op);
// new_call and new_var's code is only going to be valid for VarNode/CallNode.
// Compiler optimization will likely fold these away for other nodes.
CallNode* new_call = (std::is_base_of<CallNode, T>::value
Expand Down Expand Up @@ -702,8 +737,8 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) {
return resolved_expr;
}

struct AllCheckTypePopulated : ExprVisitor {
void VisitExpr(const Expr& e) {
struct AllCheckTypePopulated : MixedModeVisitor {
void DispatchExprVisit(const Expr& e) {
if (e.as<OpNode>()) {
return;
}
Expand Down