diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 8589f8cc4f16..d53658f87f40 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -476,6 +476,10 @@ void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_l } } } + +void ExpandANormalForm(const LetNode* op, std::function pre_visit, + std::function post_visit); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_FUNCTOR_H_ diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index abb9e6b034c2..90750575b9d4 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -141,6 +141,18 @@ class TypeVarEVisitor : private MixedModeVisitor { ExprVisitor::VisitExpr_(f); } + void VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); + } + void VisitExpr_(const ConstructorNode* cn) final { // for constructors, type vars will be bound in the module auto data = mod_->LookupTypeDef(cn->belong_to); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 74095a753950..d70c6fe2dd1f 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -532,5 +532,27 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) *ret = Bind(Downcast(input), args[1]); } }); + +void ExpandANormalForm(const LetNode* op, std::function pre_visit, + std::function post_visit) { + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + pre_visit(current_op); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + stack.pop(); + post_visit(current_op); + } +} + } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 43b71f6f10cc..2fd88736bf31 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace relay { @@ -61,8 +63,20 @@ Expr DeDup(const Expr& e) { } Expr VisitExpr_(const LetNode* op) final { - Var v = Fresh(op->var); - return Let(v, VisitExpr(op->value), VisitExpr(op->body)); + std::unordered_map new_vars; + auto pre_visit = [this, &new_vars](const LetNode* op) { + Expr expr = GetRef(op); + new_vars[expr] = this->Fresh(op->var); + // Rely on the Memoizer to cache pre-visit values + this->VisitExpr(op->value); + }; + auto post_visit = [this, &new_vars](const LetNode* op) { + Expr expr = GetRef(op); + this->memo_[expr] = + Let(new_vars[expr], this->VisitExpr(op->value), this->VisitExpr(op->body)); + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; } Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } @@ -99,7 +113,7 @@ Expr DeDup(const Expr& e) { ICHECK(WellFormed(ret)); ICHECK_EQ(FreeVars(e).size(), FreeVars(ret).size()); return ret; -} +} // namespace relay TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 66f233bbba85..0689263cca77 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -92,19 +92,33 @@ class ConstantFolder : public MixedModeMutator { using MixedModeMutator::VisitExpr_; Expr VisitExpr_(const LetNode* op) final { - Expr value = this->Mutate(op->value); - if (value.as()) { - memo_[op->var] = value; - return this->Mutate(op->body); - } else { - Var var = Downcast(this->Mutate(op->var)); - Expr body = this->Mutate(op->body); - if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + auto pre_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->Mutate(op->value); + if (value.as()) { + this->memo_[op->var] = value; } else { - return Let(var, value, body); + this->Mutate(op->var); } - } + }; + auto post_visit = [this](const LetNode* op) { + Expr expr = GetRef(op); + // Rely on the Memoizer to cache pre-visit values + Expr value = this->Mutate(op->value); + if (value.as()) { + this->memo_[expr] = this->Mutate(op->body); + } else { + Var var = Downcast(this->Mutate(op->var)); + Expr body = this->Mutate(op->body); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; } bool inside_primitive = false; diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 1b28980a0a2f..eaef0b905079 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -315,11 +315,20 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const LetNode* op) final { // do not fuse through let. - this->Update(op->var, nullptr, kOpaque); - this->Update(op->value, nullptr, kOpaque); - this->Update(op->body, nullptr, kOpaque); - ExprVisitor::VisitExpr_(op); - this->AddNode(op); + auto pre_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + this->Update(op->var, nullptr, kOpaque); + this->Update(op->value, nullptr, kOpaque); + this->Update(op->body, nullptr, kOpaque); + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + this->AddNode(op); + }; + ExpandANormalForm(op, pre_visit, post_visit); } void VisitExpr_(const IfNode* op) final { @@ -797,7 +806,7 @@ std::vector GraphPartitioner::Partition( return std::move(groups_); } -class FuseMutator : private ExprMutator { +class FuseMutator : private MixedModeMutator { public: // Run the transform Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) { @@ -814,6 +823,8 @@ class FuseMutator : private ExprMutator { } private: + using MixedModeMutator::VisitExpr_; + /*! \brief Temporary information from each group. */ struct GroupInfo { public: @@ -853,7 +864,7 @@ class FuseMutator : private ExprMutator { } // Transform calls. - Expr VisitExpr_(const CallNode* call) { + Expr Rewrite_(const CallNode* call, const Expr& post) { if (call->op.as()) { static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); @@ -886,7 +897,7 @@ class FuseMutator : private ExprMutator { } } - Expr VisitExpr_(const TupleNode* tuple) { + Expr Rewrite_(const TupleNode* tuple, const Expr& post) { auto* ret_group = gmap_.at(tuple)->FindRoot(); if (ret_group->root_ref == tuple) { return ExprMutator::VisitExpr_(tuple); @@ -896,7 +907,7 @@ class FuseMutator : private ExprMutator { return Tuple(new_fields); } - Expr VisitExpr_(const TupleGetItemNode* tuple_get) { + Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) { auto* ret_group = gmap_.at(tuple_get)->FindRoot(); auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; auto new_node = TupleGetItem(new_tuple, tuple_get->index); @@ -913,6 +924,29 @@ class FuseMutator : private ExprMutator { return std::move(new_node); } + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Var var = Downcast(this->VisitExpr(op->var)); + Expr value = this->VisitExpr(op->value); + // Visit body and cache the op + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { // If the function has no call, it is not a primitive function. struct HasCallVisitor : ExprVisitor { diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 921e83fdb092..b4ccd1659865 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -341,26 +341,34 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const OpNode* op) final { return op->op_type; } Type VisitExpr_(const LetNode* let) final { - // if the definition is a function literal, permit recursion - bool is_functional_literal = let->value.as() != nullptr; - Type let_type = IncompleteType(Kind::kType); - - if (is_functional_literal) { - let_type = GetType(let->var); - type_map_[let->var].checked_type = let_type; - } + auto pre_visit = [this](const LetNode* op) { + // if the definition is a function literal, permit recursion + bool is_functional_literal = op->value.as() != nullptr; + Type let_type = IncompleteType(Kind::kType); + + if (is_functional_literal) { + let_type = this->GetType(op->var); + this->type_map_[op->var].checked_type = let_type; + } - if (let->var->type_annotation.defined()) { - let_type = Unify(let_type, let->var->type_annotation, let->span); - } + if (op->var->type_annotation.defined()) { + let_type = this->Unify(let_type, op->var->type_annotation, op->span); + } - Type vtype = GetType(let->value); - let_type = Unify(let_type, vtype, let->span); + Type vtype = this->GetType(op->value); + let_type = this->Unify(let_type, vtype, op->span); - ICHECK(is_functional_literal || !type_map_.count(let->var)); - // NOTE: no scoping is necessary because var are unique in program - type_map_[let->var].checked_type = let_type; - return GetType(let->body); + ICHECK(is_functional_literal || !this->type_map_.count(op->var)); + // NOTE: no scoping is necessary because var are unique in program + this->type_map_[op->var].checked_type = let_type; + }; + auto post_visit = [this](const LetNode* op) { + Expr expr = GetRef(op); + this->memo_[expr] = this->GetType(op->body); + this->type_map_[expr].checked_type = this->memo_[expr]; + }; + ExpandANormalForm(let, pre_visit, post_visit); + return memo_[GetRef(let)]; } Type VisitExpr_(const IfNode* ite) final { @@ -603,7 +611,21 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); } - Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); } + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + Expr expr = GetRef(op); + Var var = Downcast(this->VisitExpr(op->var)); + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + this->memo_[expr] = this->AttachCheckedType(op, Let(var, value, body)); + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); } @@ -738,6 +760,7 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) { } struct AllCheckTypePopulated : MixedModeVisitor { + using MixedModeVisitor::VisitExpr_; void DispatchExprVisit(const Expr& e) { if (e.as()) { return; @@ -751,6 +774,17 @@ struct AllCheckTypePopulated : MixedModeVisitor { ICHECK(e->checked_type_.defined()) << "Expression: " << e; return ExprVisitor::VisitExpr(e); } + void VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); + } }; void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); }