From f4f60f27d9073b9deb766188df8a6b85b5427bc1 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 3 Feb 2020 12:55:10 +0000 Subject: [PATCH] [Relay] Ignore Primitive functions in Visitors Primitive functions are not to be modified by passes and as such should be ignored. This commit makes Visitor patterns ignore Primitive functions by default, although this can be overridden by setting enter_primitives=true. One case in which this overriding is necessary is in the type_infer pass. Change-Id: Ib9cf7f44d76dd929617493369b6a9912134085b4 --- include/tvm/relay/expr_functor.h | 2 ++ src/relay/ir/expr_functor.cc | 14 +++++++++++++- src/relay/pass/type_infer.cc | 1 + 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 68cef9483750..a26b30caf092 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -181,6 +181,7 @@ class ExprVisitor protected: // Internal visiting counter std::unordered_map visit_counter_; + bool enter_primitives_ = false; }; /*! @@ -231,6 +232,7 @@ class ExprMutator protected: /*! \brief Internal map used for memoization. */ std::unordered_map memo_; + bool enter_primitives_ = false; }; /*! diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index c525b9eb7324..72db441cb320 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -85,6 +85,13 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { Expr ExprMutator::VisitExpr_(const FunctionNode* op) { tvm::Array ty_params; bool all_ty_params_unchanged = true; + // don't mutate primitive functions + const auto primitive = FunctionGetAttr(GetRef(op), attr::kPrimitive); + if (primitive->IsInstance()) { + if (primitive.as()->value && !enter_primitives_) { + return GetRef(op); + } + } for (auto ty_param : op->type_params) { TypeVar new_ty_param = Downcast(VisitType(ty_param)); @@ -256,7 +263,12 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { for (auto param : op->params) { this->VisitExpr(param); } - + const auto primitive = FunctionGetAttr(GetRef(op), attr::kPrimitive); + if (primitive->IsInstance()) { + if (primitive.as()->value && !enter_primitives_) { + return; + } + } this->VisitExpr(op->body); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ed5f91a3f1e0..c8465f0d0291 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -604,6 +604,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { Resolver(const std::unordered_map& tmap, TypeSolver* solver) : tmap_(tmap), solver_(solver) { + enter_primitives_ = true; } Expr VisitExpr_(const VarNode* op) final {