From 19b77836afeb8fd48d19bbc7e7f8138392d06040 Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Tue, 29 Oct 2019 19:11:16 -0700 Subject: [PATCH 1/9] Fix constructor pretty printing --- src/relay/ir/pretty_printer.cc | 6 +++++- tests/python/relay/test_ir_text_printer.py | 24 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index b2a8396706f2..f42069b99603 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -669,7 +669,7 @@ class PrettyPrinter : Doc VisitExpr_(const ConstructorNode* n) final { Doc doc; doc << n->name_hint; - if (n->inputs.size() != 0) { + if (in_adt_def_ && n->inputs.size() != 0) { doc << "("; std::vector inputs; for (Type input : n->inputs) { @@ -775,6 +775,7 @@ class PrettyPrinter : } Doc VisitType_(const TypeDataNode* node) final { + in_adt_def_ = true; Doc doc; doc << "type " << Print(node->header); @@ -802,6 +803,7 @@ class PrettyPrinter : adt_body << ","; } doc << Brace(adt_body); + in_adt_def_ = false; return doc; } @@ -876,6 +878,8 @@ class PrettyPrinter : TextMetaDataContext meta_; /*! \brief counter of temporary variable */ size_t temp_var_counter_{0}; + /*! \brief whether the printer is currently in an ADT definition */ + bool in_adt_def_; /*! \brief arena for dependency graph */ common::Arena arena_; /*! \brief dependency graph of the expr */ diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 0d6a02e6c8e4..2dc082d50b5c 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -218,6 +218,29 @@ def test_zeros(): x = relay.op.zeros([], "float32") astext(x) + +def test_unapplied_constructor(): + type_def_str = r""" +type List[A] { + Cons(A, List[A]), + Nil, +} + """ + main_def_str = r""" +def @main[A]() -> fn (A, List[A]) -> List[A] { + Cons +} + """ + + mod = relay.fromtext(SEMVER + type_def_str + main_def_str) + mod_str = str(mod) + + # ensure constructors are printed correctly in type definitions (with their + # signature) and as exprs (without their signature) + assert type_def_str.strip() in mod_str + assert main_def_str.strip() in mod_str + + if __name__ == "__main__": do_print[0] = True test_lstm() @@ -239,3 +262,4 @@ def test_zeros(): test_let_if_scope() test_variable_name() test_call_node_order() + test_unapplied_constructor() From a213084702f26caa5ebacdbdaec714501fc1db24 Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Tue, 29 Oct 2019 19:18:12 -0700 Subject: [PATCH 2/9] Make Module::HasDef name consistent with API --- include/tvm/relay/module.h | 14 +++++++------- src/relay/ir/alpha_equal.cc | 2 +- src/relay/ir/module.cc | 9 ++++----- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 1ef7ca88280e..0d3f46cd3cc0 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -144,6 +144,13 @@ class ModuleNode : public RelayNode { */ TVM_DLL bool ContainGlobalVar(const std::string& name) const; + /*! + * \brief Check if the global_type_var_map_ contains a global type variable. + * \param name The variable name. + * \returns true if contains, otherise false. + */ + TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const; + /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. @@ -198,13 +205,6 @@ class ModuleNode : public RelayNode { */ TVM_DLL TypeData LookupDef(const std::string& var) const; - /*! - * \brief Check if a global type definition exists - * \param var The name of the global type definition. - * \return Whether the definition exists. - */ - TVM_DLL bool HasDef(const std::string& var) const; - /*! * \brief Look up a constructor by its tag. * \param tag The tag for the constructor. diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 0dbcf992e028..6aa46fe82a45 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -69,7 +69,7 @@ class AlphaEqualHandler: } if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; for (const auto& p : lhsm->type_definitions) { - if (!rhsm->HasDef(p.first->var->name_hint) || + if (!rhsm->ContainGlobalVar(p.first->var->name_hint) || !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) { return false; } diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 960c28f94c76..e138518c2c58 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const { return global_var_map_.find(name) != global_var_map_.end(); } +bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const { + return global_type_var_map_.find(name) != global_type_var_map_.end(); +} + GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { auto it = global_var_map_.find(name); CHECK(it != global_var_map_.end()) @@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const { return this->LookupDef(id); } -bool ModuleNode::HasDef(const std::string& name) const { - auto it = global_type_var_map_.find(name); - return it != global_type_var_map_.end(); -} - Constructor ModuleNode::LookupTag(const int32_t tag) { auto it = constructor_tag_map_.find(tag); CHECK(it != constructor_tag_map_.end()) From 3f8847c1e876b659e286df04bf646e3aaf63bf9b Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Tue, 29 Oct 2019 21:26:04 -0700 Subject: [PATCH 3/9] Add VM constructor compilation via eta expansion --- include/tvm/relay/transform.h | 9 +- python/tvm/relay/std/prelude.rly | 14 +- python/tvm/relay/transform.py | 15 +- src/relay/backend/vm/compiler.cc | 3 + src/relay/backend/vm/lambda_lift.cc | 22 ++- src/relay/ir/module.cc | 3 +- src/relay/pass/eta_expand.cc | 160 ++++++++++++++++----- src/relay/pass/type_infer.cc | 2 +- tests/python/relay/test_ir_text_printer.py | 2 - tests/python/relay/test_pass_eta_expand.py | 75 +++++++--- 10 files changed, 226 insertions(+), 79 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 82144d76e565..e110203de530 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize" TVM_DLL Pass CanonicalizeCast(); /*! - * \brief Add abstraction over a function + * \brief Add abstraction over a constructor or global variable bound to a function. * * For example: `square` is transformed to - * `fun x -> square x`. + * `fn (%x: int32) -> int32 { square(x) }`. * * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion * for more details. * + * \param expand_constructor Whether to expand constructors. + * \param expand_global_var Whether to expand global variables. + * * \return The pass. */ -TVM_DLL Pass EtaExpand(); +TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var); /*! * \brief Print the IR for a module to help debugging. diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly index a5c2c9f8a9cb..15119bb3ddb0 100644 --- a/python/tvm/relay/std/prelude.rly +++ b/python/tvm/relay/std/prelude.rly @@ -159,12 +159,7 @@ def @sum(%xs: List[Tensor[(), int32]]) { * Concatenates two lists. */ def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] { - let %updater = fn(%x: A, %xss: List[A]) -> List[A] { - Cons(%x, %xss) - }; - @foldr(%updater, %ys, %xs) - // TODO(weberlo): write it like below, once VM constructor compilation is fixed - // @foldr(Cons, %ys, %xs) + @foldr(Cons, %ys, %xs) } /* @@ -199,12 +194,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] { * Reverses a list. */ def @rev[A](%xs: List[A]) -> List[A] { - let %updater = fn(%xss: List[A], %x: A) -> List[A] { - Cons(%x, %xss) - }; - @foldl(%updater, Nil, %xs) - // TODO(weberlo): write it like below, once VM constructor compilation is fixed - // @foldl(@flip(Cons), Nil, %xs) + @foldl(@flip(Cons), Nil, %xs) } /* diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index ad1aecf7060a..30f79b0e1563 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -513,15 +513,23 @@ def ToCPS(expr, mod=None): return _transform.to_cps(expr, mod) -def EtaExpand(): - """Add abstraction over a function +def EtaExpand(expand_constructor=False, expand_global_var=False): + """Add abstraction over a constructor or global variable bound to a function + + Parameters + ---------- + expand_constructor: bool + Whether to expand constructors. + + expand_global_var: bool + Whether to expand global variables. Returns ------- ret: tvm.relay.Pass The registered pass that eta expands an expression. """ - return _transform.EtaExpand() + return _transform.EtaExpand(expand_constructor, expand_global_var) def ToGraphNormalForm(): @@ -938,6 +946,7 @@ def create_function_pass(pass_arg): return create_function_pass(pass_func) return create_function_pass + @function_pass(opt_level=1) class ChangeBatch: """ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 3cfea5c2e0db..e8247213e62f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -872,6 +872,9 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) pass_seqs.push_back(transform::Legalize()); } + pass_seqs.push_back(transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)); + pass_seqs.push_back(transform::SimplifyInference()); PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { Expr expr = args[0]; diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 6290ef7c6e93..6ef31e626dbb 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) { * We will lift a function out into a global which takes the set of the free * vars and then return the new created function. */ -struct LambdaLifter : ExprMutator { - Module module_; +class LambdaLifter : public ExprMutator { + public: explicit LambdaLifter(const Module& module) : module_(module) {} Expr VisitExpr_(const FunctionNode* func_node) final { @@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator { // The "inner" function should be used to generate the // code for the closure. Function lifted_func; - if (free_vars.size() == 0) { - lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars); + if (free_vars.size() == 0 && free_type_vars.size() == 0) { + lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params); } else { lifted_func = FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars); @@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator { auto name = GenerateName(lifted_func); auto global = GlobalVarNode::make(name); - // Add the lifted function to the module. - module_->Add(global, lifted_func); + if (module_->ContainGlobalVar(name)) { + const auto existing_func = module_->Lookup(name); + CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision"; + // If an identical function already exists, use its global var. + global = module_->GetGlobalVar(name); + } else { + // Add the lifted function to the module. + module_->Add(global, lifted_func); + } if (free_vars.size() == 0) { return std::move(global); @@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator { } return module_; } + + private: + Module module_; }; } // namespace vm diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index e138518c2c58..3bd8d59aaf49 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -335,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add") } else if (val->IsInstance()) { GlobalVar gv = Downcast(val); auto mod_copy = Module(make_node(*mod.operator->())); - mod_copy = transform::EtaExpand()(mod_copy); + mod_copy = transform::EtaExpand( + /* expand_constructor */ false, /* expand_global_var */ true)(mod_copy); auto func = mod_copy->Lookup(gv->name_hint); mod->Add(var, Downcast(func), update); } else { diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index a5d04871ba95..133ee829a193 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -20,57 +20,147 @@ /*! * \file eta_expand.cc * - * \brief Add abstraction over a function. For example, abs will become (fun x -> abs x). + * \brief Add abstraction over a constructor or global variable bound to a function. * */ #include +#include +#include "../ir/type_functor.h" #include namespace tvm { namespace relay { +namespace eta_expand { -Expr EtaExpand(const Expr& e, const Module& mod) { - tvm::Array original_params; - tvm::Array params; - tvm::Array args; - tvm::Array original_type_params; - Type ret_type; - - if (e->IsInstance()) { - auto gvar_node = e.as(); - auto func = mod->Lookup(GetRef(gvar_node)); - original_params = func->params; - original_type_params = func->type_params; - ret_type = func->ret_type; - } else { - CHECK(e->IsInstance()); - auto func = GetRef(e.as()); - original_params = func->params; - original_type_params = func->type_params; - ret_type = func->ret_type; +/*! + * \brief mutator to replace type variables with fresh ones, while maintaining alpha equality + */ +class TypeVarReplacer : public TypeMutator { + public: + TypeVarReplacer() : replace_map_({}) {} + + Type VisitType_(const TypeVarNode* type_var_node) final { + const auto type_var = GetRef(type_var_node); + if (replace_map_.find(type_var) == replace_map_.end()) { + replace_map_[type_var] = TypeVarNode::make("A", Kind::kType); + } + return replace_map_[type_var]; } - for (size_t i = 0; i < original_params.size(); ++i) { - auto var = VarNode::make("a", original_params[i]->type_annotation); - params.push_back(var); - args.push_back(var); + private: + /*! \brief variable replacement map to remap old type vars to fresh ones */ + std::unordered_map replace_map_; +}; + +/*! + * \brief mutator to perform eta expansion on all functions in a module + */ +class EtaExpander : public ExprMutator { + public: + explicit EtaExpander( + const Module& mod, + bool expand_constructor, + bool expand_global_var) + : mod_(mod) + , type_var_replacer_(TypeVarReplacer()) + , expand_constructor_(expand_constructor) + , expand_global_var_(expand_global_var) { + CHECK(expand_constructor || expand_global_var) + << "must expand at least one language feature"; } - auto new_func = - FunctionNode::make(args, CallNode::make(e, params), ret_type, original_type_params); + Module Expand() { + for (GlobalVar global_var : mod_->GetGlobalVars()) { + const Function func = mod_->Lookup(global_var); + const Function new_func = Downcast(VisitExpr(func)); + mod_->Update(global_var, new_func); + } + return mod_; + } - return std::move(new_func); -} + Expr VisitExpr_(const CallNode* call) final { + // we don't need to expand constructors when they are being called, so we + // prevent them being visited here + Expr new_op = call->op; + if (!call->op.as()) { + new_op = VisitExpr(new_op); + } + tvm::Array new_args; + for (const auto& arg : call->args) { + new_args.push_back(VisitExpr(arg)); + } + return CallNode::make(new_op, new_args, call->attrs, call->type_args); + } + + Expr VisitExpr_(const ConstructorNode* cons_node) final { + Constructor cons = GetRef(cons_node); + if (!expand_constructor_) { + return std::move(cons); + } + // NOTE: we only reach this case if the constructor is not being applied to any arguments + tvm::Array params; + for (const auto& type : cons->inputs) { + Type param_type = type_var_replacer_.VisitType(type); + params.push_back(VarNode::make("eta_expand_param", param_type)); + } + tvm::Array type_params; + TypeData adt_def = mod_->LookupDef(cons->belong_to); + for (const auto& type_var : adt_def->type_vars) { + type_params.push_back(type_var_replacer_.VisitType(type_var)); + } + Expr body = CallNode::make(cons, params, Attrs()); + Type ret_type = TypeCallNode::make(cons->belong_to, type_params); + + return FunctionNode::make( + Downcast>(params), + body, + ret_type, + Downcast>(type_params)); + } + + Expr VisitExpr_(const GlobalVarNode* gvar_node) final { + GlobalVar gvar = GetRef(gvar_node); + if (!expand_global_var_) { + return std::move(gvar); + } + + const auto func = mod_->Lookup(gvar); + tvm::Array params; + tvm::Array args; + for (size_t i = 0; i < func->params.size(); ++i) { + auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation); + params.push_back(var); + args.push_back(var); + } + + return FunctionNode::make( + args, + CallNode::make(gvar, params), + func->ret_type, + func->type_params); + } + + private: + /*! \brief reference to module being expanded */ + const Module mod_; + /*! \brief type variable replacer */ + TypeVarReplacer type_var_replacer_; + /*! \brief whether to expand constructor nodes */ + bool expand_constructor_; + /*! \brief whether to expand global variable nodes */ + bool expand_global_var_; +}; + +} // namespace eta_expand namespace transform { -Pass EtaExpand() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(EtaExpand(f, m)); - }; - Pass expanded = CreateFunctionPass(pass_func, 1, "EtaExpand", {}); - return Sequential({expanded, InferType()}); +Pass EtaExpand(bool expand_constructor, bool expand_global_var) { + runtime::TypedPackedFunc pass_func = + [=](Module mod, PassContext pc) { + return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); + }; + return CreateModulePass(pass_func, 1, "EtaExpand", {}); } TVM_REGISTER_API("relay._transform.EtaExpand") diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index bc84bddaad79..9d6878170bb5 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -653,7 +653,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } Expr VisitExpr_(const ConstructorNode* op) final { - return GetRef(op); + return AttachCheckedType(op); } Expr VisitExpr_(const MatchNode* op) final { diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 2dc082d50b5c..6426bf3410c8 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -231,10 +231,8 @@ def @main[A]() -> fn (A, List[A]) -> List[A] { Cons } """ - mod = relay.fromtext(SEMVER + type_def_str + main_def_str) mod_str = str(mod) - # ensure constructors are printed correctly in type definitions (with their # signature) and as exprs (without their signature) assert type_def_str.strip() in mod_str diff --git a/tests/python/relay/test_pass_eta_expand.py b/tests/python/relay/test_pass_eta_expand.py index 73c3a4eb4073..b9eb2a1e692d 100644 --- a/tests/python/relay/test_pass_eta_expand.py +++ b/tests/python/relay/test_pass_eta_expand.py @@ -14,27 +14,70 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import os + +import numpy as np + +import tvm from tvm import relay -import tvm.relay.module as _module import tvm.relay.transform as _transform -def test_eta_expand_basic(): - x = relay.var('x', 'int32') - orig = relay.Function([x], x) - mod = _module.Module.from_expr(orig) - seq = _transform.Sequential([_transform.EtaExpand()]) +def test_eta_expand_global_var(): + mod = relay.fromtext(r""" + v0.0.4 + def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { + %x + } + def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { + @aux + } + """) + seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)]) with _transform.PassContext(opt_level=3): mod = seq(mod) + expected = relay.fromtext(r""" + v0.0.4 + def @aux(%x: Tensor[(), int32]) -> Tensor[(), int32] { + %x + } + def @main() -> (fn(Tensor[(), int32]) -> Tensor[(), int32]) { + fn (%x: Tensor[(), int32]) -> Tensor[(), int32] { + @aux(%x) + } + } + """) + relay.analysis.assert_graph_equal(mod['main'], expected['main']) + - got = mod["main"] +def test_eta_expand_constructor(): + mod = relay.fromtext(r""" + v0.0.4 + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main[A]() -> (fn(A, List[A]) -> List[A]) { + Cons + } + """) + seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)]) + with _transform.PassContext(opt_level=3): + mod = seq(mod) + expected = relay.fromtext(r""" + v0.0.4 + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main[A]() -> (fn(A, List[A]) -> List[A]) { + fn [A](%x: A, %xs: List[A]) -> List[A] { + Cons(%x, %xs) + } + } + """) + relay.analysis.assert_graph_equal(mod['main'], expected['main']) - y = relay.var('y', 'int32') - expected = relay.Function([y], orig(y)) - gv = relay.GlobalVar("gv") - mod[gv] = expected - mod = _transform.InferType()(mod) - expected = mod["gv"] - assert(relay.analysis.alpha_equal(got, expected)) -if __name__ == "__main__": - test_eta_expand_basic() +if __name__ == '__main__': + test_eta_expand_global_var() + test_eta_expand_constructor() From 93e8257829e7606bed886516decb9159d51a3884 Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Wed, 30 Oct 2019 08:35:45 -0700 Subject: [PATCH 4/9] Lint --- src/relay/pass/eta_expand.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 133ee829a193..3e5b607e5681 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -23,10 +23,10 @@ * \brief Add abstraction over a constructor or global variable bound to a function. * */ +#include #include #include #include "../ir/type_functor.h" -#include namespace tvm { namespace relay { From 9a847619cf5a3c270fb33d5a09765ed4c21cd0b1 Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Sun, 3 Nov 2019 17:46:26 -0800 Subject: [PATCH 5/9] Fix CI --- src/relay/backend/interpreter.cc | 9 +++++++++ src/relay/backend/vm/compiler.cc | 1 + src/relay/ir/alpha_equal.cc | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 01693e5b3673..74a124d08faf 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -789,6 +790,14 @@ CreateInterpreter( Module mod, DLContext context, Target target) { + // eta expand to support constructors in argument position + transform::Sequential seq({ + transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)}); + transform::PassContext pass_ctx = transform::PassContext::Current(); + tvm::With ctx(pass_ctx); + mod = seq(mod); + auto intrp = std::make_shared(mod, context, target); auto packed = [intrp](Expr expr) { auto f = DetectFeature(expr); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e8247213e62f..a6b77b311853 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -872,6 +872,7 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets) pass_seqs.push_back(transform::Legalize()); } + // eta expand to support constructors in argument position pass_seqs.push_back(transform::EtaExpand( /* expand_constructor */ true, /* expand_global_var */ false)); diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 6aa46fe82a45..df91f794f6d1 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -69,7 +69,7 @@ class AlphaEqualHandler: } if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; for (const auto& p : lhsm->type_definitions) { - if (!rhsm->ContainGlobalVar(p.first->var->name_hint) || + if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) || !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) { return false; } From 91fc0fb2311cf626f8908a52d64f8262b3f67f24 Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Sun, 10 Nov 2019 23:00:57 -0800 Subject: [PATCH 6/9] Fix failing test --- python/tvm/relay/std/prelude.rly | 1 + src/relay/backend/interpreter.cc | 16 +++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/std/prelude.rly b/python/tvm/relay/std/prelude.rly index 15119bb3ddb0..fa05d1a7bd98 100644 --- a/python/tvm/relay/std/prelude.rly +++ b/python/tvm/relay/std/prelude.rly @@ -158,6 +158,7 @@ def @sum(%xs: List[Tensor[(), int32]]) { /* * Concatenates two lists. */ + def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] { @foldr(Cons, %ys, %xs) } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 74a124d08faf..45283582bf05 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -790,13 +790,15 @@ CreateInterpreter( Module mod, DLContext context, Target target) { - // eta expand to support constructors in argument position - transform::Sequential seq({ - transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)}); - transform::PassContext pass_ctx = transform::PassContext::Current(); - tvm::With ctx(pass_ctx); - mod = seq(mod); + if (mod.defined()) { + // eta expand to support constructors in argument position + transform::Sequential seq({ + transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)}); + transform::PassContext pass_ctx = transform::PassContext::Current(); + tvm::With ctx(pass_ctx); + mod = seq(mod); + } auto intrp = std::make_shared(mod, context, target); auto packed = [intrp](Expr expr) { From 27d462c5ce8eeea3bb4e1f7183d4bb8d5a018323 Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Tue, 12 Nov 2019 08:54:05 -0800 Subject: [PATCH 7/9] Address comment --- src/relay/pass/eta_expand.cc | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/relay/pass/eta_expand.cc b/src/relay/pass/eta_expand.cc index 3e5b607e5681..dca08cc834d1 100644 --- a/src/relay/pass/eta_expand.cc +++ b/src/relay/pass/eta_expand.cc @@ -20,7 +20,7 @@ /*! * \file eta_expand.cc * - * \brief Add abstraction over a constructor or global variable bound to a function. + * \brief Add an abstraction over constructors and/or global variables bound to a function. * */ #include @@ -57,14 +57,11 @@ class TypeVarReplacer : public TypeMutator { */ class EtaExpander : public ExprMutator { public: - explicit EtaExpander( - const Module& mod, - bool expand_constructor, - bool expand_global_var) - : mod_(mod) - , type_var_replacer_(TypeVarReplacer()) - , expand_constructor_(expand_constructor) - , expand_global_var_(expand_global_var) { + explicit EtaExpander(const Module& mod, bool expand_constructor, bool expand_global_var) + : mod_(mod), + type_var_replacer_(TypeVarReplacer()), + expand_constructor_(expand_constructor), + expand_global_var_(expand_global_var) { CHECK(expand_constructor || expand_global_var) << "must expand at least one language feature"; } From b9b0b4360523bb84a524fd9dc0482321c091d205 Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Tue, 12 Nov 2019 21:09:04 -0800 Subject: [PATCH 8/9] Retrigger CI From da596884b1dc81c47f883a2c40926dc4aa18483d Mon Sep 17 00:00:00 2001 From: Logan Weber Date: Thu, 14 Nov 2019 11:42:06 -0800 Subject: [PATCH 9/9] Retrigger CI