From 17aff1cc44f9ab8c8927e43be2777a1bb36c8c48 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Mon, 4 Oct 2021 09:22:59 -0700 Subject: [PATCH] Address Christopher's comments from #8788 We don't need the Optional on ToANormalForm and friends. --- include/tvm/relay/transform.h | 3 +- src/relay/transforms/higher_order_gradient.cc | 2 +- src/relay/transforms/pass_utils.h | 2 +- src/relay/transforms/to_a_normal_form.cc | 38 +++++++++++-------- .../transforms/to_basic_block_normal_form.cc | 2 +- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 91f731410863..e740776d6d4f 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -165,12 +165,11 @@ TVM_DLL Pass ToANormalForm(); /*! * \brief ToANormalForm but on incomplete graph. * - * \param maybe_mod optional module holding definitions for global vars in \p expr * \param expr the graph. * * \return The transformed program. */ -TVM_DLL Expr ToANormalForm(const Optional& maybe_mod, const Expr& expr); +TVM_DLL Expr ToANormalForm(const Expr& expr); /*! * \brief Turn an expression into continuation passing style(CPS). diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index 3adff6e3099a..202275626d5d 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -293,7 +293,7 @@ struct ReverseAD : ExprMutator { return Call(bpv, {}); }); Expr nbp = Function({}, nbp_body, TupleType::Empty(), {}); - ll->Push(RefWrite(bp, transform::ToANormalForm(mod, nbp))); + ll->Push(RefWrite(bp, transform::ToANormalForm(nbp))); // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that. return ret; }); diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index 5638804b4aa2..ed9409856871 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -227,7 +227,7 @@ std::pair CalcScope(const DependencyGraph& dg); Scope LCA(Scope lhs, Scope rhs); // For basic block normal form. -Expr ToBasicBlockNormalFormAux(const Optional& maybe_mod, const Expr& e); +Expr ToBasicBlockNormalFormAux(const Expr& e); // ToANormalForm for expressions and as a Pass are declared in transform.h diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 1dc45d38518e..c767770a8be8 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -149,25 +149,31 @@ namespace { */ class Fill : ExprFunctor, private transform::LexicalOnDeviceMixin { public: - static Expr ToANormalForm(const Optional& maybe_mod, const Expr& e, - const DependencyGraph& dg, NodeScopeMap* node_scope) { - Fill fi(maybe_mod, dg, node_scope, nullptr); + static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) { + Fill fi(dg, node_scope, nullptr); return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); } // For basic block normal form, bind expressions only if the original expression's scope // should be lifted - static Expr ToBasicBlockNormalForm(const Optional& maybe_mod, const Expr& e, - const DependencyGraph& dg, NodeScopeMap* node_scope, - ExprSet* lifted) { - Fill fi(maybe_mod, dg, node_scope, lifted); + static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, + NodeScopeMap* node_scope, ExprSet* lifted) { + Fill fi(dg, node_scope, lifted); return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e)); } private: - Fill(const Optional& maybe_mod, const DependencyGraph& dg, NodeScopeMap* node_scope, - ExprSet* include_set) - : transform::LexicalOnDeviceMixin(maybe_mod), + // Note: Conversion to ANF needn't care about the devices for global vars since all that can + // happen with them is to go from: + // ...@g... + // to: + // let %x = @g; + // ... + // ...%x... + // In that case the code will ask for the device for @g, get kInvalidDeviceType, then + // MaybeOnDevice @g, which is always a no-op. + Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set) + : transform::LexicalOnDeviceMixin(Optional()), dg_(dg), node_scope_(node_scope), include_set_(include_set) {} @@ -373,7 +379,7 @@ IRModule ModuleToANormalForm(const IRModule& mod) { if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; Function func = GetRef(n); - Function ret = Downcast(transform::ToANormalForm(mod, func)); + Function ret = Downcast(transform::ToANormalForm(func)); ICHECK_EQ(FreeVars(ret).size(), 0) << "rewritten:" << std::endl << PrettyPrint(ret) << std::endl << "should not have free vars: " << FreeVars(ret); @@ -394,7 +400,7 @@ IRModule ModuleToANormalForm(const IRModule& mod) { } // namespace -Expr ToBasicBlockNormalFormAux(const Optional& maybe_mod, const Expr& e) { +Expr ToBasicBlockNormalFormAux(const Expr& e) { // calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e); @@ -403,12 +409,12 @@ Expr ToBasicBlockNormalFormAux(const Optional& maybe_mod, const Expr& * We also record the set of expressions whose scope is lifted. */ std::pair scopes = CalcScope(dg); - return Fill::ToBasicBlockNormalForm(maybe_mod, e, dg, &scopes.first, &scopes.second); + return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second); } namespace transform { -Expr ToANormalForm(const Optional& maybe_mod, const Expr& e) { +Expr ToANormalForm(const Expr& e) { /* When you lift a lambda, what is inside is also being lift. * * So we must determine the scope of the lambda before determining the scope of it's body. @@ -431,7 +437,7 @@ Expr ToANormalForm(const Optional& maybe_mod, const Expr& e) { * We do an additional pass to fill all the LetList and we are done. */ std::pair scopes = CalcScope(dg); - return Fill::ToANormalForm(maybe_mod, e, dg, &scopes.first); + return Fill::ToANormalForm(e, dg, &scopes.first); } Pass ToANormalForm() { @@ -445,7 +451,7 @@ TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() { }); TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) { - return ToANormalForm(Optional(), e); + return ToANormalForm(e); }); } // namespace transform diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 826006b0e603..931543d2640c 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -46,7 +46,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) { if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; Function func = GetRef(n); - Function ret = Downcast(ToBasicBlockNormalFormAux(mod, func)); + Function ret = Downcast(ToBasicBlockNormalFormAux(func)); VLOG(1) << "rewritten:" << std::endl << PrettyPrint(func) << std::endl << "to BasicBlockANF:" << std::endl