diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 68cef9483750..a26b30caf092 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -181,6 +181,7 @@ class ExprVisitor protected: // Internal visiting counter std::unordered_map visit_counter_; + bool enter_primitives_ = false; }; /*! @@ -231,6 +232,7 @@ class ExprMutator protected: /*! \brief Internal map used for memoization. */ std::unordered_map memo_; + bool enter_primitives_ = false; }; /*! diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index c525b9eb7324..72db441cb320 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -85,6 +85,13 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) { Expr ExprMutator::VisitExpr_(const FunctionNode* op) { tvm::Array ty_params; bool all_ty_params_unchanged = true; + // don't mutate primitive functions + const auto primitive = FunctionGetAttr(GetRef(op), attr::kPrimitive); + if (primitive->IsInstance()) { + if (primitive.as()->value && !enter_primitives_) { + return GetRef(op); + } + } for (auto ty_param : op->type_params) { TypeVar new_ty_param = Downcast(VisitType(ty_param)); @@ -256,7 +263,12 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) { for (auto param : op->params) { this->VisitExpr(param); } - + const auto primitive = FunctionGetAttr(GetRef(op), attr::kPrimitive); + if (primitive->IsInstance()) { + if (primitive.as()->value && !enter_primitives_) { + return; + } + } this->VisitExpr(op->body); } diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ed5f91a3f1e0..c8465f0d0291 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -604,6 +604,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { Resolver(const std::unordered_map& tmap, TypeSolver* solver) : tmap_(tmap), solver_(solver) { + enter_primitives_ = true; } Expr VisitExpr_(const VarNode* op) final {