From f31847a910d054170b580e91430728887cdcc276 Mon Sep 17 00:00:00 2001
From: Li Xiaoquan
Date: Sat, 17 Oct 2020 11:19:04 +0800
Subject: [PATCH] [Relay] Mix mode type inference

---
 include/tvm/relay/expr_functor.h   | 68 ++++++++++++++++++++++++++++++
 src/relay/ir/expr_functor.cc       | 68 ------------------------------
 src/relay/op/algorithm/topk.cc     |  2 +-
 src/relay/transforms/type_infer.cc | 51 ++++++++++++++++++----
 4 files changed, 112 insertions(+), 77 deletions(-)

diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index df0940fa7482..8589f8cc4f16 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -32,6 +32,7 @@
 #include <tvm/relay/function.h>
 #include <tvm/relay/op.h>
 
+#include <stack>
 #include <string>
 #include <unordered_map>
 #include <utility>
@@ -408,6 +409,73 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
  */
 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);
 
+/*!
+ * \brief A function to iteratively traverse dataflow regions of a graph
+ *
+ * ExpandDataflow manually manages a stack and performs DFS to determine the processing
+ * order of nodes in an input graph.
+ *
+ * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
+ * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
+ * and continues iteratively to process the top of the stack. When it finds a node that doesn't
+ * match the dataflow types, or a node whose inputs have all been processed, it visits the current
+ * leaf via fvisit_leaf.
+ *
+ * This function should be used internally by other classes to implement mixed-mode traversals. The
+ * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
+ * hits a non-dataflow node.
+ *
+ * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
+ */
+template <typename FCheckVisited, typename FVisitLeaf>
+void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
+  std::stack<std::pair<Expr, bool>> stack;
+  auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
+    // The second element of the stack pair indicates whether the node has been
+    // expanded in the pre-order pass.
+    // NOTE: this function will be inlined.
+    if (!fcheck_visited(expr)) {
+      stack.push({expr, false});
+    }
+  };
+  fpush_to_stack(expr);
+  while (stack.size() > 0) {
+    auto node = stack.top().first;
+    if (fcheck_visited(node)) {
+      // if this node was visited through another path
+      // after being added to the stack, ignore it.
+      stack.pop();
+    } else if (stack.top().second) {
+      // all the children have already been expanded;
+      // we can just run the post-order visit on it.
+      fvisit_leaf(node);
+      stack.pop();
+    } else if (const CallNode* op = node.as<CallNode>()) {
+      // mark expanded = true
+      stack.top().second = true;
+      // push the children to the stack in reverse order
+      // to match the recursive processing order
+      for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
+        fpush_to_stack(*it);
+      }
+      fpush_to_stack(op->op);
+    } else if (const TupleNode* op = node.as<TupleNode>()) {
+      stack.top().second = true;
+      // push the children to the stack in reverse order
+      // to match the recursive processing order
+      for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
+        fpush_to_stack(*it);
+      }
+    } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
+      stack.top().second = true;
+      fpush_to_stack(op->tuple);
+    } else {
+      // No need to expand the children; directly run the visit.
+      fvisit_leaf(node);
+      stack.pop();
+    }
+  }
+}
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_EXPR_FUNCTOR_H_
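To make the contract of ExpandDataflow concrete, here is a minimal sketch of a caller (illustrative only, not part of the patch; CountNodes and `visited` are made-up names). Note that fcheck_visited doubles as the memo lookup, and fvisit_leaf must mark the node visited, otherwise subgraphs shared through multiple paths would be expanded again:

// Sketch only (not part of the patch): driving ExpandDataflow directly.
// CountNodes and `visited` are illustrative names, not TVM API.
#include <tvm/relay/expr_functor.h>

#include <unordered_set>

namespace tvm {
namespace relay {

size_t CountNodes(const Expr& root) {
  std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited;
  size_t count = 0;
  auto fcheck_visited = [&](const Expr& e) { return visited.count(e) != 0; };
  auto fvisit_leaf = [&](const Expr& e) {
    // Runs once per node in post order; dataflow children are already done.
    // Non-dataflow nodes (Let, If, Function, ...) arrive unexpanded, so a
    // real traversal would recurse into their children here.
    visited.insert(e);
    ++count;
  };
  ExpandDataflow(root, fcheck_visited, fvisit_leaf);
  return count;
}

}  // namespace relay
}  // namespace tvm

Marking the node inside fvisit_leaf mirrors what TypeInferencer::VisitLeaf does with memo_ further down in this patch.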
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index a22b69c4ed1b..74095a753950 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -33,74 +33,6 @@
 namespace tvm {
 namespace relay {
 
-/*!
- * \brief A function to iteratively traverse dataflow regions of a graph
- *
- * ExpandDataflow manually manages a stack and performs DFS to determine the processing
- * order of nodes in an input graph.
- *
- * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
- * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
- * and continues iteratively to process the top of the stack. When it finds a node that doesn't
- * match the dataflow types, or a node who's inputs have all been processed, it visits the current
- * leaf via fvisit_leaf.
- *
- * This function should be used internally to other classes to implement mixed-mode traversals. The
- * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
- * hits a non-dataflow node.
- *
- * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
- */
-template <typename FCheckVisited, typename FVisitLeaf>
-void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
-  std::stack<std::pair<Expr, bool>> stack;
-  auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
-    // The second state of the stack indicate whether the child has been
-    // expanded in the pre-order.
-    // NOTE: function will be inlined.
-    if (!fcheck_visited(expr)) {
-      stack.push({expr, false});
-    }
-  };
-  fpush_to_stack(expr);
-  while (stack.size() > 0) {
-    auto node = stack.top().first;
-    if (fcheck_visited(node)) {
-      // if this node was visited through another path
-      // after being added to the stack ignore it.
-      stack.pop();
-    } else if (stack.top().second) {
-      // all the children have already been expanded.
-      // we can just run post order visit on it.
-      fvisit_leaf(node);
-      stack.pop();
-    } else if (const CallNode* op = node.as<CallNode>()) {
-      // mark expanded = true
-      stack.top().second = true;
-      // push the children to the stack in reverse order
-      // to match recursive processing order
-      for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
-        fpush_to_stack(*it);
-      }
-      fpush_to_stack(op->op);
-    } else if (const TupleNode* op = node.as<TupleNode>()) {
-      stack.top().second = true;
-      // push the children to the stack in reverse order
-      // to match recursive processing order
-      for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
-        fpush_to_stack(*it);
-      }
-    } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
-      stack.top().second = true;
-      fpush_to_stack(op->tuple);
-    } else {
-      // No need to expand the children directly run visit.
-      fvisit_leaf(node);
-      stack.pop();
-    }
-  }
-}
-
 MixedModeVisitor::MixedModeVisitor(int visit_limit) {
   ICHECK(visit_limit > 0) << "Dataflow visit limit must be greater than 0";
   ICHECK(visit_limit < 10) << "Dataflow visit limit must be less than 10";
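ExpandDataflow moves to the header so that other traversals can instantiate the template; MixedModeVisitor, whose constructor appears in the surrounding context lines, is the ready-made visitor built on top of it. A typical subclass looks like this (sketch, not part of the patch; CallCounter is a made-up name):

// Sketch only (not part of the patch): a typical MixedModeVisitor subclass.
// Dataflow regions are iterated rather than recursed, so deeply chained
// models no longer overflow the C++ stack.
#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

class CallCounter : public MixedModeVisitor {
 public:
  size_t count = 0;

  void VisitExpr_(const CallNode* op) final {
    ++count;
    // Forward to the base class, which bounds re-visits of shared subgraphs
    // by the visit_limit passed to the constructor (default 1).
    MixedModeVisitor::VisitExpr_(op);
  }
};

}  // namespace relay
}  // namespace tvm

A visit_limit greater than 1 is only needed by passes that deliberately revisit shared nodes a bounded number of times, which is what the constructor's range checks above enforce.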
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc
index b0e4b5dc6b4e..c1d3e5472743 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/algorithm/topk.cc
@@ -36,7 +36,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   const TopKAttrs* param = attrs.as<TopKAttrs>();
   ICHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
-  ICHECK(data);
+  if (data == nullptr) return false;
   int ndim = data->shape.size();
   int axis = param->axis;
   if (axis < 0) {
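The topk.cc change is the behavioral counterpart of the new traversal order: with mixed-mode inference a relation can now run while its input is still an IncompleteType, so it must report "not yet solvable" and let the solver retry, instead of asserting. The same defensive pattern as a self-contained sketch (not part of the patch; IdentityRel is a hypothetical relation whose output type simply matches its input):

// Sketch only (not part of the patch): defensive pattern for a type relation.
#include <tvm/ir/tensor_type.h>
#include <tvm/ir/type_relation.h>

namespace tvm {
namespace relay {

bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    // Input is still an IncompleteType: tell the solver the relation cannot
    // be resolved yet, so it runs again later instead of aborting.
    return false;
  }
  reporter->Assign(types[1], TensorType(data->shape, data->dtype));
  return true;
}

}  // namespace relay
}  // namespace tvm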
diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index cb3ba0030a5b..327b5d1e260a 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -129,6 +129,37 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   TypeRelationFn tuple_getitem_rel_;
   TypeRelationFn make_tuple_rel_;
 
+  /*! \brief Internal map used for memoization. */
+  std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> memo_;
+
+  void VisitLeaf(const Expr& expr) {
+    if (!memo_.count(expr)) {
+      Type ret = this->DispatchVisitExpr(expr);
+      memo_[expr] = ret;
+    }
+  }
+
+  bool CheckVisited(const Expr& expr) {
+    if (memo_.count(expr)) {
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  Type DispatchVisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); }
+
+  Type VisitExpr(const Expr& expr) final {
+    auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
+    auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
+    if (memo_.count(expr)) {
+      return memo_[expr];
+    } else {
+      ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
+      return memo_[expr];
+    }
+  }
+
   // Perform unification on two types and report the error at the expression
   // or the span of the expression.
   Type Unify(const Type& t1, const Type& t2, const Span& span) {
@@ -546,12 +577,14 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   }
 };
 
-class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
+class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
  public:
   Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap,
            TypeSolver* solver)
       : tmap_(tmap), solver_(solver) {}
 
+  using MixedModeMutator::VisitExpr_;
+
   Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }
 
   Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); }
 
@@ -560,13 +593,15 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
 
   Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); }
 
-  Expr VisitExpr_(const TupleNode* op) final { return AttachCheckedType(op); }
+  Expr Rewrite_(const TupleNode* op, const Expr& post) final { return AttachCheckedType(op, post); }
 
-  Expr VisitExpr_(const TupleGetItemNode* op) final { return AttachCheckedType(op); }
+  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
+    return AttachCheckedType(op, post);
+  }
 
   Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); }
 
-  Expr VisitExpr_(const CallNode* op) final { return AttachCheckedType(op); }
+  Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); }
 
   Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); }
 
@@ -593,7 +628,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
 
   // attach checked type to the mutated node.
   template <typename T>
-  Expr AttachCheckedType(const T* op) {
+  Expr AttachCheckedType(const T* op, const Expr& post = Expr()) {
     auto it = tmap_.find(GetRef<Expr>(op));
     ICHECK(it != tmap_.end());
     Type checked_type = solver_->Resolve(it->second.checked_type);
@@ -606,7 +641,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
           << " check other reported errors for hints of what may of happened.");
     }
 
-    Expr new_e = ExprMutator::VisitExpr_(op);
+    Expr new_e = post.defined() ? post : ExprMutator::VisitExpr_(op);
     // new_call and new_var's code is only going to be valid for VarNode/CallNode.
     // Compiler optimization will likely fold these away for other nodes.
     CallNode* new_call = (std::is_base_of<CallNode, T>::value
@@ -702,8 +737,8 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) {
   return resolved_expr;
 }
 
-struct AllCheckTypePopulated : ExprVisitor {
-  void VisitExpr(const Expr& e) {
+struct AllCheckTypePopulated : MixedModeVisitor {
+  void DispatchExprVisit(const Expr& e) {
     if (e.as<OpNode>()) {
       return;
    }
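For reference, the Rewrite_/post idiom that Resolver now follows, shown on a self-contained example (sketch, not part of the patch; SimplifyTupleGet is a made-up pass that folds TupleGetItem(Tuple(fields), i) into fields[i]):

// Sketch only (not part of the patch): the Rewrite_ pattern on a
// MixedModeMutator.
#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

class SimplifyTupleGet : public MixedModeMutator {
 public:
  Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) final {
    // `pre` is the original node; `post` is the same node with its dataflow
    // children already rewritten, so inspect `post` rather than recursing.
    const auto* get = post.as<TupleGetItemNode>();
    if (const auto* tuple = get->tuple.as<TupleNode>()) {
      return tuple->fields[get->index];
    }
    return post;
  }
};

}  // namespace relay
}  // namespace tvm

Applying it is just SimplifyTupleGet().Mutate(expr). Because MixedModeMutator only intercepts the dataflow nodes (Call, Tuple, TupleGetItem), non-dataflow nodes still go through the usual VisitExpr_ overloads; that is why Resolver adds `using MixedModeMutator::VisitExpr_;`, keeping both overload sets visible side by side.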