From 9ceb0363022e984fa39c6db794e9564ca6c94890 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 29 Jan 2021 14:53:42 -0700 Subject: [PATCH 1/5] [WIP][Relay][Passes] non-recursive a-normal traversals --- src/relay/analysis/util.cc | 22 ++++++ src/relay/transforms/de_duplicate.cc | 30 +++++++- src/relay/transforms/fold_constant.cc | 49 +++++++++--- src/relay/transforms/fuse_ops.cc | 70 ++++++++++++++--- src/relay/transforms/type_infer.cc | 104 +++++++++++++++++++++----- 5 files changed, 236 insertions(+), 39 deletions(-) diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index abb9e6b034c2..fa989ebfe42c 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -141,6 +141,28 @@ class TypeVarEVisitor : private MixedModeVisitor { ExprVisitor::VisitExpr_(f); } + void VisitExpr_(const LetNode* op) final { + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + VisitExpr(current_op->var); + VisitExpr(current_op->value); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + stack.pop(); + VisitExpr(current_op->body); + visit_counter_[current_op] += 1; + } + } + void VisitExpr_(const ConstructorNode* cn) final { // for constructors, type vars will be bound in the module auto data = mod_->LookupTypeDef(cn->belong_to); diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 43b71f6f10cc..4a16cb724e3a 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace relay { @@ -61,8 +63,30 @@ Expr DeDup(const Expr& e) { } Expr VisitExpr_(const LetNode* op) final { - Var v = Fresh(op->var); - return Let(v, VisitExpr(op->value), VisitExpr(op->body)); + std::unordered_map new_vars; + std::unordered_map new_values; + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + new_vars[current_expr] = Fresh(current_op->var); + new_values[current_expr] = VisitExpr(current_op->value); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + stack.pop(); + memo_[current_expr] = + Let(new_vars[current_expr], new_values[current_expr], VisitExpr(current_op->body)); + } + return memo_[GetRef(op)]; } Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } @@ -99,7 +123,7 @@ Expr DeDup(const Expr& e) { ICHECK(WellFormed(ret)); ICHECK_EQ(FreeVars(e).size(), FreeVars(ret).size()); return ret; -} +} // namespace relay TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 66f233bbba85..705948c3a612 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -92,19 +92,48 @@ class ConstantFolder : public MixedModeMutator { using MixedModeMutator::VisitExpr_; Expr VisitExpr_(const LetNode* op) final { - Expr value = this->Mutate(op->value); - if (value.as()) { - memo_[op->var] = value; - return this->Mutate(op->body); - } else { - Var var = Downcast(this->Mutate(op->var)); - Expr body = this->Mutate(op->body); - if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); + std::unordered_map new_vars; + std::unordered_map new_values; + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + + Expr value = this->Mutate(current_op->value); + new_values[current_expr] = value; + if (value.as()) { + memo_[current_op->var] = value; } else { - return Let(var, value, body); + new_vars[current_expr] = Downcast(this->Mutate(current_op->var)); + } + + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + stack.pop(); + Expr value = new_values[current_expr]; + if (value.as()) { + memo_[current_expr] = this->Mutate(current_op->body); + } else { + Var var = new_vars[current_expr]; + Expr body = this->Mutate(current_op->body); + if (var.same_as(current_op->var) && value.same_as(current_op->value) && + body.same_as(current_op->body)) { + memo_[current_expr] = current_expr; + } else { + memo_[current_expr] = Let(var, value, body); + } } } + return memo_[GetRef(op)]; } bool inside_primitive = false; diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 1b28980a0a2f..e4f43faf07c0 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -315,11 +315,29 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const LetNode* op) final { // do not fuse through let. - this->Update(op->var, nullptr, kOpaque); - this->Update(op->value, nullptr, kOpaque); - this->Update(op->body, nullptr, kOpaque); - ExprVisitor::VisitExpr_(op); - this->AddNode(op); + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + this->Update(current_op->var, nullptr, kOpaque); + this->Update(current_op->value, nullptr, kOpaque); + this->Update(current_op->body, nullptr, kOpaque); + VisitExpr(current_op->var); + VisitExpr(current_op->value); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + stack.pop(); + VisitExpr(current_op->body); + visit_counter_[current_op] += 1; + this->AddNode(current_op); + } } void VisitExpr_(const IfNode* op) final { @@ -797,7 +815,7 @@ std::vector GraphPartitioner::Partition( return std::move(groups_); } -class FuseMutator : private ExprMutator { +class FuseMutator : private MixedModeMutator { public: // Run the transform Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) { @@ -853,7 +871,7 @@ class FuseMutator : private ExprMutator { } // Transform calls. - Expr VisitExpr_(const CallNode* call) { + Expr Rewrite_(const CallNode* call, const Expr& post) { if (call->op.as()) { static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); @@ -886,7 +904,7 @@ class FuseMutator : private ExprMutator { } } - Expr VisitExpr_(const TupleNode* tuple) { + Expr Rewrite_(const TupleNode* tuple, const Expr& post) { auto* ret_group = gmap_.at(tuple)->FindRoot(); if (ret_group->root_ref == tuple) { return ExprMutator::VisitExpr_(tuple); @@ -896,7 +914,7 @@ class FuseMutator : private ExprMutator { return Tuple(new_fields); } - Expr VisitExpr_(const TupleGetItemNode* tuple_get) { + Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) { auto* ret_group = gmap_.at(tuple_get)->FindRoot(); auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0]; auto new_node = TupleGetItem(new_tuple, tuple_get->index); @@ -913,6 +931,40 @@ class FuseMutator : private ExprMutator { return std::move(new_node); } + Expr VisitExpr_(const LetNode* op) final { + std::unordered_map new_vars; + std::unordered_map new_values; + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + new_vars[current_expr] = Downcast(VisitExpr(current_op->var)); + new_values[current_expr] = VisitExpr(current_op->value); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + stack.pop(); + Var var = new_vars[current_expr]; + Expr value = new_values[current_expr]; + Expr body = VisitExpr(current_op->body); + if (var.same_as(current_op->var) && value.same_as(current_op->value) && + body.same_as(current_op->body)) { + memo_[current_expr] = current_expr; + } else { + memo_[current_expr] = Let(var, value, body); + } + } + return memo_[GetRef(op)]; + } + Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { // If the function has no call, it is not a primitive function. struct HasCallVisitor : ExprVisitor { diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 921e83fdb092..c37f5343628d 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -341,26 +341,47 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const OpNode* op) final { return op->op_type; } Type VisitExpr_(const LetNode* let) final { - // if the definition is a function literal, permit recursion - bool is_functional_literal = let->value.as() != nullptr; - Type let_type = IncompleteType(Kind::kType); + std::stack stack; + stack.push(let); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + const Expr current_expr = GetRef(current_op); + // if the definition is a function literal, permit recursion + bool is_functional_literal = current_op->value.as() != nullptr; + Type let_type = IncompleteType(Kind::kType); + + if (is_functional_literal) { + let_type = GetType(current_op->var); + type_map_[current_op->var].checked_type = let_type; + } - if (is_functional_literal) { - let_type = GetType(let->var); - type_map_[let->var].checked_type = let_type; - } + if (current_op->var->type_annotation.defined()) { + let_type = Unify(let_type, current_op->var->type_annotation, current_op->span); + } - if (let->var->type_annotation.defined()) { - let_type = Unify(let_type, let->var->type_annotation, let->span); - } + Type vtype = GetType(current_op->value); + let_type = Unify(let_type, vtype, current_op->span); + + ICHECK(is_functional_literal || !type_map_.count(current_op->var)); + // NOTE: no scoping is necessary because var are unique in program + type_map_[current_op->var].checked_type = let_type; - Type vtype = GetType(let->value); - let_type = Unify(let_type, vtype, let->span); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + stack.pop(); + memo_[current_expr] = GetType(current_op->body); + type_map_[current_expr].checked_type = memo_[current_expr]; + } - ICHECK(is_functional_literal || !type_map_.count(let->var)); - // NOTE: no scoping is necessary because var are unique in program - type_map_[let->var].checked_type = let_type; - return GetType(let->body); + return memo_[GetRef(let)]; } Type VisitExpr_(const IfNode* ite) final { @@ -603,7 +624,35 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); } - Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); } + Expr VisitExpr_(const LetNode* op) final { + std::unordered_map new_vars; + std::unordered_map new_values; + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + const Expr current_expr = GetRef(current_op); + new_vars[current_expr] = Downcast(VisitExpr(current_op->var)); + new_values[current_expr] = VisitExpr(current_op->value); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + Expr current_expr = GetRef(current_op); + stack.pop(); + Var var = new_vars[current_expr]; + Expr value = new_values[current_expr]; + Expr body = VisitExpr(current_op->body); + memo_[current_expr] = AttachCheckedType(current_op, Let(var, value, body)); + } + + return memo_[GetRef(op)]; + } Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); } @@ -751,6 +800,27 @@ struct AllCheckTypePopulated : MixedModeVisitor { ICHECK(e->checked_type_.defined()) << "Expression: " << e; return ExprVisitor::VisitExpr(e); } + void VisitExpr_(const LetNode* op) final { + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + VisitExpr(current_op->var); + VisitExpr(current_op->value); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + stack.pop(); + VisitExpr(current_op->body); + visit_counter_[current_op] += 1; + } + } }; void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } From 29e8605e0a4b07c237d555286f2f85d6a855e51c Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Fri, 29 Jan 2021 19:45:45 -0700 Subject: [PATCH 2/5] fix clang warning --- src/relay/transforms/fuse_ops.cc | 2 ++ src/relay/transforms/type_infer.cc | 1 + 2 files changed, 3 insertions(+) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index e4f43faf07c0..4420aec6416b 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -832,6 +832,8 @@ class FuseMutator : private MixedModeMutator { } private: + using MixedModeMutator::VisitExpr_; + /*! \brief Temporary information from each group. */ struct GroupInfo { public: diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index c37f5343628d..4513bf242e09 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -787,6 +787,7 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) { } struct AllCheckTypePopulated : MixedModeVisitor { + using MixedModeVisitor::VisitExpr_; void DispatchExprVisit(const Expr& e) { if (e.as()) { return; From a2ce628effa08a87b0b9ac05bf2a476c7669f97c Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Mon, 1 Feb 2021 12:57:14 -0700 Subject: [PATCH 3/5] Refactor ANormal Iterative traversal into a higher order function utility with lambdas --- include/tvm/relay/expr_functor.h | 4 + src/relay/analysis/util.cc | 28 ++----- src/relay/ir/expr_functor.cc | 22 +++++ src/relay/transforms/de_duplicate.cc | 33 +++----- src/relay/transforms/fold_constant.cc | 51 ++++-------- src/relay/transforms/fuse_ops.cc | 45 ++++------ src/relay/transforms/type_infer.cc | 113 +++++++++----------------- 7 files changed, 119 insertions(+), 177 deletions(-) diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 8589f8cc4f16..d53658f87f40 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -476,6 +476,10 @@ void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_l } } } + +void ExpandANormalForm(const LetNode* op, std::function pre_visit, + std::function post_visit); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_FUNCTOR_H_ diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index fa989ebfe42c..90750575b9d4 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -142,25 +142,15 @@ class TypeVarEVisitor : private MixedModeVisitor { } void VisitExpr_(const LetNode* op) final { - std::stack stack; - stack.push(op); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - VisitExpr(current_op->var); - VisitExpr(current_op->value); - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - stack.pop(); - VisitExpr(current_op->body); - visit_counter_[current_op] += 1; - } + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); } void VisitExpr_(const ConstructorNode* cn) final { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 74095a753950..d70c6fe2dd1f 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -532,5 +532,27 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) *ret = Bind(Downcast(input), args[1]); } }); + +void ExpandANormalForm(const LetNode* op, std::function pre_visit, + std::function post_visit) { + std::stack stack; + stack.push(op); + bool is_anormal = true; + while (is_anormal) { + const LetNode* current_op = stack.top(); + pre_visit(current_op); + if (const LetNode* new_op = current_op->body.as()) { + stack.push(new_op); + } else { + is_anormal = false; + } + } + while (stack.size()) { + const LetNode* current_op = stack.top(); + stack.pop(); + post_visit(current_op); + } +} + } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 4a16cb724e3a..1cddc55981b3 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -64,28 +64,17 @@ Expr DeDup(const Expr& e) { Expr VisitExpr_(const LetNode* op) final { std::unordered_map new_vars; - std::unordered_map new_values; - std::stack stack; - stack.push(op); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - new_vars[current_expr] = Fresh(current_op->var); - new_values[current_expr] = VisitExpr(current_op->value); - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - stack.pop(); - memo_[current_expr] = - Let(new_vars[current_expr], new_values[current_expr], VisitExpr(current_op->body)); - } + auto pre_visit = [this, &new_vars](const LetNode* op) { + Expr expr = GetRef(op); + new_vars[expr] = Fresh(op->var); + // Rely on the Memoizer to cache pre-visit values + VisitExpr(op->value); + }; + auto post_visit = [this, &new_vars](const LetNode* op) { + Expr expr = GetRef(op); + memo_[expr] = Let(new_vars[expr], VisitExpr(op->value), VisitExpr(op->body)); + }; + ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; } diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 705948c3a612..5ea01fc698ca 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -92,47 +92,32 @@ class ConstantFolder : public MixedModeMutator { using MixedModeMutator::VisitExpr_; Expr VisitExpr_(const LetNode* op) final { - std::unordered_map new_vars; - std::unordered_map new_values; - std::stack stack; - stack.push(op); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - - Expr value = this->Mutate(current_op->value); - new_values[current_expr] = value; + auto pre_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->Mutate(op->value); if (value.as()) { - memo_[current_op->var] = value; + memo_[op->var] = value; } else { - new_vars[current_expr] = Downcast(this->Mutate(current_op->var)); + this->Mutate(op->var); } - - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - stack.pop(); - Expr value = new_values[current_expr]; + }; + auto post_visit = [this](const LetNode* op) { + Expr expr = GetRef(op); + // Rely on the Memoizer to cache pre-visit values + Expr value = this->Mutate(op->value); if (value.as()) { - memo_[current_expr] = this->Mutate(current_op->body); + memo_[expr] = this->Mutate(op->body); } else { - Var var = new_vars[current_expr]; - Expr body = this->Mutate(current_op->body); - if (var.same_as(current_op->var) && value.same_as(current_op->value) && - body.same_as(current_op->body)) { - memo_[current_expr] = current_expr; + Var var = Downcast(this->Mutate(op->var)); + Expr body = this->Mutate(op->body); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + memo_[expr] = expr; } else { - memo_[current_expr] = Let(var, value, body); + memo_[expr] = Let(var, value, body); } } - } + }; + ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; } diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 4420aec6416b..09ac785dcc02 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -934,36 +934,25 @@ class FuseMutator : private MixedModeMutator { } Expr VisitExpr_(const LetNode* op) final { - std::unordered_map new_vars; - std::unordered_map new_values; - std::stack stack; - stack.push(op); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - new_vars[current_expr] = Downcast(VisitExpr(current_op->var)); - new_values[current_expr] = VisitExpr(current_op->value); - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - stack.pop(); - Var var = new_vars[current_expr]; - Expr value = new_values[current_expr]; - Expr body = VisitExpr(current_op->body); - if (var.same_as(current_op->var) && value.same_as(current_op->value) && - body.same_as(current_op->body)) { - memo_[current_expr] = current_expr; + auto pre_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Var var = Downcast(VisitExpr(op->var)); + Expr value = VisitExpr(op->value); + // Visit body and cache the op + Expr body = VisitExpr(op->body); + auto expr = GetRef(op); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + memo_[expr] = expr; } else { - memo_[current_expr] = Let(var, value, body); + memo_[expr] = Let(var, value, body); } - } + }; + ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 4513bf242e09..560452d055da 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -341,46 +341,33 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const OpNode* op) final { return op->op_type; } Type VisitExpr_(const LetNode* let) final { - std::stack stack; - stack.push(let); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - const Expr current_expr = GetRef(current_op); + auto pre_visit = [this](const LetNode* op) { // if the definition is a function literal, permit recursion - bool is_functional_literal = current_op->value.as() != nullptr; + bool is_functional_literal = op->value.as() != nullptr; Type let_type = IncompleteType(Kind::kType); if (is_functional_literal) { - let_type = GetType(current_op->var); - type_map_[current_op->var].checked_type = let_type; + let_type = GetType(op->var); + type_map_[op->var].checked_type = let_type; } - if (current_op->var->type_annotation.defined()) { - let_type = Unify(let_type, current_op->var->type_annotation, current_op->span); + if (op->var->type_annotation.defined()) { + let_type = Unify(let_type, op->var->type_annotation, op->span); } - Type vtype = GetType(current_op->value); - let_type = Unify(let_type, vtype, current_op->span); + Type vtype = GetType(op->value); + let_type = Unify(let_type, vtype, op->span); - ICHECK(is_functional_literal || !type_map_.count(current_op->var)); + ICHECK(is_functional_literal || !type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[current_op->var].checked_type = let_type; - - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - stack.pop(); - memo_[current_expr] = GetType(current_op->body); - type_map_[current_expr].checked_type = memo_[current_expr]; - } - + type_map_[op->var].checked_type = let_type; + }; + auto post_visit = [this](const LetNode* op) { + Expr expr = GetRef(op); + memo_[expr] = GetType(op->body); + type_map_[expr].checked_type = memo_[expr]; + }; + ExpandANormalForm(let, pre_visit, post_visit); return memo_[GetRef(let)]; } @@ -625,32 +612,18 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); } Expr VisitExpr_(const LetNode* op) final { - std::unordered_map new_vars; - std::unordered_map new_values; - std::stack stack; - stack.push(op); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - const Expr current_expr = GetRef(current_op); - new_vars[current_expr] = Downcast(VisitExpr(current_op->var)); - new_values[current_expr] = VisitExpr(current_op->value); - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - Expr current_expr = GetRef(current_op); - stack.pop(); - Var var = new_vars[current_expr]; - Expr value = new_values[current_expr]; - Expr body = VisitExpr(current_op->body); - memo_[current_expr] = AttachCheckedType(current_op, Let(var, value, body)); - } - + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + Expr expr = GetRef(op); + Var var = Downcast(VisitExpr(op->var)); + Expr value = VisitExpr(op->value); + Expr body = VisitExpr(op->body); + memo_[expr] = AttachCheckedType(op, Let(var, value, body)); + }; + ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; } @@ -802,25 +775,15 @@ struct AllCheckTypePopulated : MixedModeVisitor { return ExprVisitor::VisitExpr(e); } void VisitExpr_(const LetNode* op) final { - std::stack stack; - stack.push(op); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - VisitExpr(current_op->var); - VisitExpr(current_op->value); - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - stack.pop(); - VisitExpr(current_op->body); - visit_counter_[current_op] += 1; - } + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); } }; From a10993c2412049067d0d4b7c3c30049105c40f32 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Mon, 1 Feb 2021 13:03:15 -0700 Subject: [PATCH 4/5] refactor missed pass --- src/relay/transforms/fuse_ops.cc | 37 ++++++++++++-------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 09ac785dcc02..0a1440b97ab1 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -315,29 +315,20 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const LetNode* op) final { // do not fuse through let. - std::stack stack; - stack.push(op); - bool is_anormal = true; - while (is_anormal) { - const LetNode* current_op = stack.top(); - this->Update(current_op->var, nullptr, kOpaque); - this->Update(current_op->value, nullptr, kOpaque); - this->Update(current_op->body, nullptr, kOpaque); - VisitExpr(current_op->var); - VisitExpr(current_op->value); - if (const LetNode* new_op = current_op->body.as()) { - stack.push(new_op); - } else { - is_anormal = false; - } - } - while (stack.size()) { - const LetNode* current_op = stack.top(); - stack.pop(); - VisitExpr(current_op->body); - visit_counter_[current_op] += 1; - this->AddNode(current_op); - } + auto pre_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + this->Update(op->var, nullptr, kOpaque); + this->Update(op->value, nullptr, kOpaque); + this->Update(op->body, nullptr, kOpaque); + VisitExpr(op->var); + VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + VisitExpr(op->body); + visit_counter_[op] += 1; + this->AddNode(op); + }; + ExpandANormalForm(op, pre_visit, post_visit); } void VisitExpr_(const IfNode* op) final { From 15cb245c49ef6f2e91ef882d19cf32636fa61c8c Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Mon, 1 Feb 2021 13:48:54 -0700 Subject: [PATCH 5/5] add explict use of to lamdbas --- src/relay/transforms/de_duplicate.cc | 7 ++++--- src/relay/transforms/fold_constant.cc | 8 ++++---- src/relay/transforms/fuse_ops.cc | 18 +++++++++--------- src/relay/transforms/type_infer.cc | 26 +++++++++++++------------- 4 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 1cddc55981b3..2fd88736bf31 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -66,13 +66,14 @@ Expr DeDup(const Expr& e) { std::unordered_map new_vars; auto pre_visit = [this, &new_vars](const LetNode* op) { Expr expr = GetRef(op); - new_vars[expr] = Fresh(op->var); + new_vars[expr] = this->Fresh(op->var); // Rely on the Memoizer to cache pre-visit values - VisitExpr(op->value); + this->VisitExpr(op->value); }; auto post_visit = [this, &new_vars](const LetNode* op) { Expr expr = GetRef(op); - memo_[expr] = Let(new_vars[expr], VisitExpr(op->value), VisitExpr(op->body)); + this->memo_[expr] = + Let(new_vars[expr], this->VisitExpr(op->value), this->VisitExpr(op->body)); }; ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 5ea01fc698ca..0689263cca77 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -96,7 +96,7 @@ class ConstantFolder : public MixedModeMutator { // Rely on the Memoizer to cache pre-visit values Expr value = this->Mutate(op->value); if (value.as()) { - memo_[op->var] = value; + this->memo_[op->var] = value; } else { this->Mutate(op->var); } @@ -106,14 +106,14 @@ class ConstantFolder : public MixedModeMutator { // Rely on the Memoizer to cache pre-visit values Expr value = this->Mutate(op->value); if (value.as()) { - memo_[expr] = this->Mutate(op->body); + this->memo_[expr] = this->Mutate(op->body); } else { Var var = Downcast(this->Mutate(op->var)); Expr body = this->Mutate(op->body); if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - memo_[expr] = expr; + this->memo_[expr] = expr; } else { - memo_[expr] = Let(var, value, body); + this->memo_[expr] = Let(var, value, body); } } }; diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 0a1440b97ab1..eaef0b905079 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -320,12 +320,12 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->Update(op->var, nullptr, kOpaque); this->Update(op->value, nullptr, kOpaque); this->Update(op->body, nullptr, kOpaque); - VisitExpr(op->var); - VisitExpr(op->value); + this->VisitExpr(op->var); + this->VisitExpr(op->value); }; auto post_visit = [this](const LetNode* op) { - VisitExpr(op->body); - visit_counter_[op] += 1; + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; this->AddNode(op); }; ExpandANormalForm(op, pre_visit, post_visit); @@ -932,15 +932,15 @@ class FuseMutator : private MixedModeMutator { }; auto post_visit = [this](const LetNode* op) { // Rely on the Memoizer to cache pre-visit values - Var var = Downcast(VisitExpr(op->var)); - Expr value = VisitExpr(op->value); + Var var = Downcast(this->VisitExpr(op->var)); + Expr value = this->VisitExpr(op->value); // Visit body and cache the op - Expr body = VisitExpr(op->body); + Expr body = this->VisitExpr(op->body); auto expr = GetRef(op); if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - memo_[expr] = expr; + this->memo_[expr] = expr; } else { - memo_[expr] = Let(var, value, body); + this->memo_[expr] = Let(var, value, body); } }; ExpandANormalForm(op, pre_visit, post_visit); diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 560452d055da..b4ccd1659865 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -347,25 +347,25 @@ class TypeInferencer : private ExprFunctor, Type let_type = IncompleteType(Kind::kType); if (is_functional_literal) { - let_type = GetType(op->var); - type_map_[op->var].checked_type = let_type; + let_type = this->GetType(op->var); + this->type_map_[op->var].checked_type = let_type; } if (op->var->type_annotation.defined()) { - let_type = Unify(let_type, op->var->type_annotation, op->span); + let_type = this->Unify(let_type, op->var->type_annotation, op->span); } - Type vtype = GetType(op->value); - let_type = Unify(let_type, vtype, op->span); + Type vtype = this->GetType(op->value); + let_type = this->Unify(let_type, vtype, op->span); - ICHECK(is_functional_literal || !type_map_.count(op->var)); + ICHECK(is_functional_literal || !this->type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program - type_map_[op->var].checked_type = let_type; + this->type_map_[op->var].checked_type = let_type; }; auto post_visit = [this](const LetNode* op) { Expr expr = GetRef(op); - memo_[expr] = GetType(op->body); - type_map_[expr].checked_type = memo_[expr]; + this->memo_[expr] = this->GetType(op->body); + this->type_map_[expr].checked_type = this->memo_[expr]; }; ExpandANormalForm(let, pre_visit, post_visit); return memo_[GetRef(let)]; @@ -618,10 +618,10 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { }; auto post_visit = [this](const LetNode* op) { Expr expr = GetRef(op); - Var var = Downcast(VisitExpr(op->var)); - Expr value = VisitExpr(op->value); - Expr body = VisitExpr(op->body); - memo_[expr] = AttachCheckedType(op, Let(var, value, body)); + Var var = Downcast(this->VisitExpr(op->var)); + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + this->memo_[expr] = this->AttachCheckedType(op, Let(var, value, body)); }; ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)];