Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,11 @@ TVM_DLL Pass ToANormalForm();
/*!
* \brief ToANormalForm but on incomplete graph.
*
* \param maybe_mod optional module holding definitions for global vars in \p expr
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& expr);
TVM_DLL Expr ToANormalForm(const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style(CPS).
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/higher_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ struct ReverseAD : ExprMutator {
return Call(bpv, {});
});
Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
ll->Push(RefWrite(bp, transform::ToANormalForm(mod, nbp)));
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);
Scope LCA(Scope lhs, Scope rhs);

// For basic block normal form.
Expr ToBasicBlockNormalFormAux(const Optional<IRModule>& maybe_mod, const Expr& e);
Expr ToBasicBlockNormalFormAux(const Expr& e);

// ToANormalForm for expressions and as a Pass are declared in transform.h

Expand Down
38 changes: 22 additions & 16 deletions src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,25 +149,31 @@ namespace {
*/
class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::LexicalOnDeviceMixin {
public:
static Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& e,
const DependencyGraph& dg, NodeScopeMap* node_scope) {
Fill fi(maybe_mod, dg, node_scope, nullptr);
static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) {
Fill fi(dg, node_scope, nullptr);
return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
}

// For basic block normal form, bind expressions only if the original expression's scope
// should be lifted
static Expr ToBasicBlockNormalForm(const Optional<IRModule>& maybe_mod, const Expr& e,
const DependencyGraph& dg, NodeScopeMap* node_scope,
ExprSet* lifted) {
Fill fi(maybe_mod, dg, node_scope, lifted);
static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
NodeScopeMap* node_scope, ExprSet* lifted) {
Fill fi(dg, node_scope, lifted);
return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
}

private:
Fill(const Optional<IRModule>& maybe_mod, const DependencyGraph& dg, NodeScopeMap* node_scope,
ExprSet* include_set)
: transform::LexicalOnDeviceMixin(maybe_mod),
// Note: Conversion to ANF needn't care about the devices for global vars since all that can
// happen with them is to go from:
// ...@g...
// to:
// let %x = @g;
// ...
// ...%x...
// In that case the code will ask for the device for @g, get kInvalidDeviceType, then
// MaybeOnDevice @g, which is always a no-op.
Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set)
: transform::LexicalOnDeviceMixin(Optional<IRModule>()),
dg_(dg),
node_scope_(node_scope),
include_set_(include_set) {}
Expand Down Expand Up @@ -373,7 +379,7 @@ IRModule ModuleToANormalForm(const IRModule& mod) {
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
Function func = GetRef<Function>(n);
Function ret = Downcast<Function>(transform::ToANormalForm(mod, func));
Function ret = Downcast<Function>(transform::ToANormalForm(func));
ICHECK_EQ(FreeVars(ret).size(), 0) << "rewritten:" << std::endl
<< PrettyPrint(ret) << std::endl
<< "should not have free vars: " << FreeVars(ret);
Expand All @@ -394,7 +400,7 @@ IRModule ModuleToANormalForm(const IRModule& mod) {

} // namespace

Expr ToBasicBlockNormalFormAux(const Optional<IRModule>& maybe_mod, const Expr& e) {
Expr ToBasicBlockNormalFormAux(const Expr& e) {
// calculate all the dependency between nodes.
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
Expand All @@ -403,12 +409,12 @@ Expr ToBasicBlockNormalFormAux(const Optional<IRModule>& maybe_mod, const Expr&
* We also record the set of expressions whose scope is lifted.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToBasicBlockNormalForm(maybe_mod, e, dg, &scopes.first, &scopes.second);
return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
}

namespace transform {

Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& e) {
Expr ToANormalForm(const Expr& e) {
/* When you lift a lambda, what is inside is also being lift.
*
* So we must determine the scope of the lambda before determining the scope of it's body.
Expand All @@ -431,7 +437,7 @@ Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& e) {
* We do an additional pass to fill all the LetList and we are done.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToANormalForm(maybe_mod, e, dg, &scopes.first);
return Fill::ToANormalForm(e, dg, &scopes.first);
}

Pass ToANormalForm() {
Expand All @@ -445,7 +451,7 @@ TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() {
});

TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) {
return ToANormalForm(Optional<IRModule>(), e);
return ToANormalForm(e);
});

} // namespace transform
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/to_basic_block_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) {
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
Function func = GetRef<Function>(n);
Function ret = Downcast<Function>(ToBasicBlockNormalFormAux(mod, func));
Function ret = Downcast<Function>(ToBasicBlockNormalFormAux(func));
VLOG(1) << "rewritten:" << std::endl
<< PrettyPrint(func) << std::endl
<< "to BasicBlockANF:" << std::endl
Expand Down