From 2bd044d77edbfa4af45c32153b44f17ae8dbf157 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 31 Aug 2020 14:43:18 -0700 Subject: [PATCH 01/17] type args not automatically inferred... --- python/tvm/relay/transform/transform.py | 15 ++ src/relay/transforms/defunctionalization.cc | 250 ++++++++++++++++++ .../relay/test_pass_defunctionalization.py | 42 +++ 3 files changed, 307 insertions(+) create mode 100644 src/relay/transforms/defunctionalization.cc create mode 100644 tests/python/relay/test_pass_defunctionalization.py diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index de3f9861c96e..f64e80568ee1 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -736,6 +736,21 @@ def gradient(expr, mod=None, mode='higher_order'): return _ffi_api.gradient(expr, mod) raise Exception('unknown mode') +def Defunctionalization(expr, mod): + """ + Parameters + ---------- + expr : tvm.relay.Expr + The input expression, which is a Function or a GlobalVar. + + mod : tvm.IRModule + + Returns + ------- + expr : tvm.relay.Expr + The transformed expression. + """ + return _ffi_api.Defunctionalization(expr, mod) def to_cps(func, mod=None): """ diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc new file mode 100644 index 000000000000..bc532d9e4c88 --- /dev/null +++ b/src/relay/transforms/defunctionalization.cc @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file defunctionalization.cc + * + * \brief + */ + +#include +#include +#include +#include +#include +#include + +#include "../transforms/pass_util.h" +namespace tvm { +namespace relay { + +struct FuncTypeVisitor : TypeVisitor { + bool has_func_type; + FuncTypeVisitor() : has_func_type(false) {} + + void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; } +}; + +bool HasFuncType(const Expr& e) { + auto visitor = FuncTypeVisitor(); + visitor.VisitType(e->checked_type()); + return visitor.has_func_type; +} + +bool HasFuncType(const Type& t) { + auto visitor = FuncTypeVisitor(); + visitor.VisitType(t); + return visitor.has_func_type; +} + +bool IsHigherOrderFunc(const FuncType& t) { + bool higher_order = false; + for (auto arg: t->arg_types) { + higher_order |= HasFuncType(arg); + } + return higher_order |= HasFuncType(t->ret_type); +} + +class DefuncMutator : public ExprMutator { + public: + DefuncMutator(const IRModule& mod) : mod(mod), constructor_name(0) {} + + Expr VisitExpr_(const CallNode* op) { + auto op_func = op->op; + auto f = op_func.as(); + std::cout << op_func << std::endl; + CHECK(f) << "only calls to functions are supported so far"; + + // clone function and specialize if there are higher order functions + if (IsHigherOrderFunc(Downcast(f->checked_type()))) { + auto f_clone = Downcast(Clone(f, op->type_args)); + std::cout << f_clone << std::endl; + auto f_clone_type = Downcast(f_clone->checked_type()); + CHECK(FreeTypeVars(f_clone_type, mod).size() == 0) + << "free type vars in specialized function"; + CHECK(FreeVars(f_clone).size() == FreeVars(GetRef(f)).size()) + << "local closures not supported yet"; + CHECK(!HasFuncType(f_clone_type->ret_type)) << "returning function not supported yet"; + + Array args; + std::unordered_map applyVars; + for (size_t i = 0; i < f_clone_type->arg_types.size(); i++) { + if (f_clone_type->arg_types[i].as()) { + auto arg = EncodeFunctionArg(op->args[i], f_clone_type->arg_types[i].as()); + args.push_back(arg); + applyVars[f_clone->params[i]] = apply_map[f_clone->params[i]->checked_type()]; + } + + CHECK(!HasFuncType(f_clone_type->arg_types[i])) << "nested function type in parameter not supported yet"; + args.push_back(op->args[i]); + } + + auto new_func = ApplyVars(f_clone, applyVars); + + return Call(ExprMutator::VisitExpr(new_func), args); + } + return ExprMutator::VisitExpr(GetRef(op)); + } + + // Expr VisitExpr_(const LetNode* op) { + // var_map[op->var] = this->VisitExpr(op->value); + // return this->VisitExpr(op->body); + // } + + // Expr VisitExpr_(const VarNode* op) { + // if (var_map.count(GetRef(op)) != 0) { + // return var_map[GetRef(op)]; + // } + // return GetRef(op); + // } + + Expr VisitExpr_(const GlobalVarNode* op) { CHECK(false) << "global var not supported yet"; + throw std::runtime_error("GlobalVar not supported"); + } + + private: + IRModule mod; + // encode func type to ADT + std::unordered_map func_encoding; + std::unordered_map apply_map; + // use monotonically increasing integer to represent new constructor_name + unsigned int constructor_name; + + Expr ApplyVars(Expr body, const std::unordered_map& vars) { + struct ApplyVarMutator: public ExprMutator { + std::unordered_map vars; + ApplyVarMutator(const std::unordered_map& vars) : vars(vars) {} + Expr VisitExpr_(const CallNode* op) { + if (auto var_op = op->op.as()) { + if (vars.count(GetRef(var_op)) != 0) { + auto gv = vars[GetRef(var_op)]; + Array args = {GetRef(var_op)}; + for (auto arg: op->args) { + args.push_back(arg); + } + return Call(gv, args); + } + } + + return ExprMutator::VisitExpr_(op); + } + }; + + return ApplyVarMutator(vars).Mutate(body); + } + + void AddConstructor(GlobalTypeVar gtv, Constructor c) { + if (!mod->ContainGlobalTypeVar(gtv->name_hint)) { + mod->AddTypeDef(gtv, TypeData(gtv, {}, {c})); + } else { + auto typedata = mod->LookupTypeDef(gtv); + auto constructors = typedata->constructors; + constructors.push_back(c); + mod->UpdateTypeDef(gtv, TypeData(typedata->header, typedata->type_vars, constructors)); + } + } + + void AddApplyCase(GlobalVar gv, FuncType ft, Constructor c) { + if (!mod->ContainGlobalVar(gv->name_hint)) { + auto x = Var("x", func_encoding[ft]); + auto vars = Array({x}); + auto args = Array(); + for (auto t: ft->arg_types) { + auto y = Var("y", t); + vars.push_back(y); + args.push_back(y); + } + + + auto clauses = Array({Clause(PatternConstructor(c, {}), Call(x, args))}); + auto body = Match(x, clauses); + auto f = Function(vars, body, ft->ret_type, {}); + + mod->Add(gv, f); + } else { + auto f = Downcast(mod->Lookup(gv)); + auto body = f->body.as(); + CHECK(body) << "internal invariant broken; apply function body should be a match node"; + + auto clauses = body->clauses; + auto x = f->params[0]; + auto args = Array(); + for (size_t i = 1; i < f->params.size(); i++) { + args.push_back(f->params[i]); + } + clauses.push_back(Clause(PatternConstructor(c, {}), Call(x, args))); + + mod->Add(gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true); + } + } + + Expr EncodeFunctionArg(const Expr& f, const FuncTypeNode* ft) { + if (func_encoding.count(GetRef(ft)) == 0) { + func_encoding[GetRef(ft)] = GlobalTypeVar("T" + TypeToString(ft), TypeKind::kAdtHandle); + } + + auto gtv = func_encoding[GetRef(ft)]; + auto c = Constructor(std::to_string(constructor_name++), {}, gtv); + AddConstructor(gtv, c); + + if (apply_map.count(GetRef(ft)) == 0) { + apply_map[GetRef(ft)] = GlobalVar("apply" + TypeToString(ft)); + } + + auto gv = apply_map[GetRef(ft)]; + AddApplyCase(gv, GetRef(ft), c); + + return Call(c, {}); + } + + std::string TypeToString(const TypeNode* t) { + std::ostringstream s; + s << t; + return s.str(); + } + + Expr Clone(const FunctionNode* f, const Array type_args) { + return DeDup(Specialize(f, type_args)); + } + + Expr Specialize(const FunctionNode* f, const Array type_args) { + auto map = tvm::Map(); + for (size_t i = 0; i < type_args.size(); i++) { + map.Set(f->type_params[i], type_args[i]); + } + return TypeSubst(GetRef(f), map); + } +}; + +Expr Defunctionalization(const Expr& e, const IRModule& mod) { + auto f = e.as(); + CHECK(f) << "input need to be a function"; + CHECK(f->type_params.size() == 0) << "no polymorphism supported for defunctionalization"; + for (const auto& p : f->params) { + CHECK(!HasFuncType(p)) << "input parameters cannot have func type"; + } + + return DefuncMutator(mod).VisitExpr(e); +} + +TVM_REGISTER_GLOBAL("relay._transform.Defunctionalization").set_body_typed(Defunctionalization); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py new file mode 100644 index 000000000000..d73433c7f32c --- /dev/null +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import relay +from tvm.relay.transform import Defunctionalization, InferType + +def test_local_simple(): + code = """ +#[version = "0.0.5"] +def @main(%l: float32) -> float32 { + %0 = fn[A, B](%f: fn(A) -> B, %xs: A) -> B { + %f(%xs) + }; + %1 = fn[A](%x: A) -> A { + %x + }; + %0(%1, %l) +} +""" + mod = tvm.parser.fromtext(code) + mod = InferType()(mod) + expr = Defunctionalization(mod['main'], mod) + +if __name__ == "__main__": + # pytest.main([__file__]) + test_local_simple() \ No newline at end of file From c936da69203ba23ca4c37af58bb014ba89bbdf1a Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 1 Sep 2020 20:17:07 -0700 Subject: [PATCH 02/17] working on type arg infer --- src/relay/analysis/util.cc | 15 + src/relay/transforms/defunctionalization.cc | 280 +++++++++++++++--- src/relay/transforms/gradient.cc | 16 - src/relay/transforms/pass_util.h | 4 + .../relay/test_pass_defunctionalization.py | 39 ++- 5 files changed, 296 insertions(+), 58 deletions(-) diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index b98106a091b3..ef25d0e54add 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -449,6 +449,21 @@ Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { return ret; } +Expr DeGlobal(const Optional& mod, const Expr& e) { + const auto* x = e.as(); + + if (mod.defined() && x) { + BaseFunc base_func = mod.value()->Lookup(GetRef(x)); + if (auto* n = base_func.as()) { + return GetRef(n); + } else { + return e; + } + } else { + return e; + } +} + struct IsDynamicVisitor : public TypeVisitor { bool is_dyn{false}; void VisitType_(const TensorTypeNode* tt) { diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index bc532d9e4c88..a4de180cfd03 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -31,6 +31,7 @@ #include #include +#include "../analysis/type_solver.h" #include "../transforms/pass_util.h" namespace tvm { namespace relay { @@ -62,21 +63,203 @@ bool IsHigherOrderFunc(const FuncType& t) { return higher_order |= HasFuncType(t->ret_type); } +Array InferTypeArgs(const CallNode* call, const IRModule& mod) { + // struct InferTypeArgsVisitor: public TypeFunctor { + // std::unordered_map typearg_map; + // std::unordered_set type_args; + // std::unordered_set poly_arg_type_params; + // InferTypeArgsVisitor(const std::unordered_set& type_args) : type_args(type_args) {} + // void VisitType_(const TypeVarNode* t, const Type& b) { + // if (auto tv = b.as()) { + // if (poly_arg_type_params.count(GetRef(tv)) > 0) { + // return; + // } + // } + + // auto tv = GetRef(t); + // if (type_args.count(tv) > 0) { + // if (typearg_map.count(tv) > 0) { + // std::cout << "L: " << typearg_map[tv] << std::endl; + // std::cout << "R: " << b << std::endl; + // CHECK(StructuralEqual()(typearg_map[tv], b)) << "failed to infer type args"; + // } + // typearg_map[tv] = b; + // } + // } + + // void VisitType_(const TensorTypeNode* t, const Type& b) {} + // void VisitType_(const TypeConstraintNode* t, const Type& b) {} + // void VisitType_(const FuncTypeNode* t, const Type& b) { + // if (auto tv = b.as()) { + // if (poly_arg_type_params.count(GetRef(tv)) > 0) { + // return; + // } + // } + + // auto fty = b.as(); + // CHECK(fty) << "expected func type when infering type args"; + // CHECK(t->arg_types.size() == fty->arg_types.size()) << "incorrect number of args when infering type args"; + + // if (fty->type_params.size() > 0) { + // for (auto t: fty->type_params) { + // poly_arg_type_params.insert(t); + // } + // } + + // for (size_t i = 0; i < t->arg_types.size(); i++) { + // this->VisitType(t->arg_types[i], fty->arg_types[i]); + // } + // this->VisitType(t->ret_type, fty->ret_type); + // } + // void VisitType_(const TupleTypeNode* t, const Type& b) { + // if (auto tv = b.as()) { + // if (poly_arg_type_params.count(GetRef(tv)) > 0) { + // return; + // } + // } + + // auto ty = b.as(); + // CHECK(ty) << "expected tuple type when infering type args"; + // CHECK(t->fields.size() == ty->fields.size()) << "incorrect tuple size when infering type args"; + // for (size_t i = 0; i < t->fields.size(); i++) { + // this->VisitType(t->fields[i], ty->fields[i]); + // } + // } + // void VisitType_(const TypeRelationNode* t, const Type& b) {} + // void VisitType_(const IncompleteTypeNode* t, const Type& b) { + // CHECK(false) << "encountered incompletetype when inferring type args"; + // } + // void VisitType_(const RelayRefTypeNode* t, const Type& b) { + // if (auto tv = b.as()) { + // if (poly_arg_type_params.count(GetRef(tv)) > 0) { + // return; + // } + // } + + // auto ty = b.as(); + // CHECK(ty) << "expected ref type when infering type args"; + // this->VisitType(t->value, ty->value); + // } + // void VisitType_(const GlobalTypeVarNode* t, const Type& b) { + // } + // void VisitType_(const TypeCallNode* t, const Type& b) { + // if (auto tv = b.as()) { + // if (poly_arg_type_params.count(GetRef(tv)) > 0) { + // return; + // } + // } + + // auto ty = b.as(); + // CHECK(ty) << "expected tuple type when infering type args"; + // CHECK(t->args.size() == ty->args.size()) << "incorrect tuple size when infering type args"; + // for (size_t i = 0; i < t->args.size(); i++) { + // this->VisitType(t->args[i], ty->args[i]); + // } + // } + // void VisitType_(const TypeDataNode* t, const Type& b) {} + // void VisitType_(const PrimTypeNode* t, const Type& b) {} + // void VisitType_(const PointerTypeNode* t, const Type& b) {} + // }; + + // std::unordered_set type_args; + // for (auto tv: f->type_params) { type_args.insert(tv); } + // auto itav = InferTypeArgsVisitor(type_args); + // for (size_t i = 0; i < f->params.size(); i++) { + // itav.VisitType(f->params[i]->checked_type(), args[i]->checked_type()); + // } + + // Array typeargs; + // for (auto tv: f->type_params) { typeargs.push_back(itav.typearg_map[tv]); + // std::cout<< "resolved type" << itav.typearg_map[tv] << std::endl; + // } + + // return typeargs; + std::cout << "START" << std::endl; + ErrorReporter err; + TypeSolver solver(mod->GetGlobalVar("main"), mod, &err); + const FuncTypeNode* fn_ty = call->op->checked_type().as(); + + tvm::Map subst_map; + for (auto& tv: fn_ty->type_params) { + subst_map.Set(tv, IncompleteType(Kind::kType)); + } + + auto inst_fnty = FuncType(fn_ty->arg_types, fn_ty->ret_type, {}, {}); + auto f_incomplete = Downcast(Bind(inst_fnty, subst_map)); + std::cout << f_incomplete << std::endl; + // Array arg_types; + + // for (auto t: call->args) { + // auto ty = t->checked_type(); + // auto bound = BoundTypeVars(ty, mod); + // for (auto tv: bound) { + // subst_map.Set(tv, IncompleteType(Kind::kType)); + // } + // if (auto fn_ty = ty.as()) { + // arg_types.push_back(TypeSubst(FuncType(fn_ty->arg_types, fn_ty->ret_type, {}, fn_ty->type_constraints), subst_map)); + // } else { + // arg_types.push_back(TypeSubst(ty, subst_map)); + // } + // } + std::cout << "REACHED" << std::endl; + // CHECK(arg_types.size() == f_incomplete->arg_types.size()) << "num of arguments does not match expected"; + size_t num_args = f_incomplete->arg_types.size(); + // for (size_t i = 0; i < num_args; i++) { + // std::cout << "size: " << num_args << "; i: " << i << std::endl; + // // auto t1 = f_incomplete->arg_types[i]; + // // auto t2 = call->args[i]->checked_type(); + // // std::cout << "l: "<< t1 << std::endl; + // // std::cout << "r: "<(call)); + // // std::cout << "Univifed: " << t << std::endl; + // } + for (size_t i = 0; i < num_args; i++) { + std::cout << "i: " << i << "; num_args: " << num_args << std::endl; + } + + // for (size_t i = 0; i < f_incomplete->arg_types.size(); i++) { + // std::cout << "size: " << f_incomplete->arg_types.size() << "; i: " << i << std::endl; + // auto t1 = f_incomplete->arg_types[i]; + // auto t2 = arg_types[i]; + // std::cout << "l: " << t1 << " r: " << t2 << std::endl; + + // try { + // auto t = solver.Unify(t1, t2, GetRef(call)); + // std::cout << "Univifed: " << t << std::endl; + // } catch (const dmlc::Error& e) { + // CHECK(false) << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what(); + // } + // } + + // for (auto& tv: fn_ty->type_params) { + // std::cout << "Resolved: " << solver.Resolve(subst_map[tv]); + // } +} + class DefuncMutator : public ExprMutator { public: - DefuncMutator(const IRModule& mod) : mod(mod), constructor_name(0) {} + DefuncMutator(const IRModule& mod) : mod(mod), constructor_name(0), anon_name(0) {} Expr VisitExpr_(const CallNode* op) { auto op_func = op->op; - auto f = op_func.as(); - std::cout << op_func << std::endl; - CHECK(f) << "only calls to functions are supported so far"; - + auto f = DeGlobal(mod, op_func).as(); + CHECK(f) << "only calls to functions or globalvars are supported so far"; + // CHECK(op->type_args.size() == f->type_params.size()) << "all type args must be explicit"; + // clone function and specialize if there are higher order functions - if (IsHigherOrderFunc(Downcast(f->checked_type()))) { - auto f_clone = Downcast(Clone(f, op->type_args)); + if (IsHigherOrderFunc(Downcast(f->func_type_annotation()))) { + std::cout << "Call Function: " << GetRef(f) << std::endl; + + std::string name; + if (auto gv = op->op.as()) { + name = gv->name_hint; + } else { + name = "anon" + std::to_string(anon_name++); + } + auto clone_gv = Clone(name, f, InferTypeArgs(op, mod)); + auto f_clone = Downcast(DeGlobal(mod, clone_gv)); std::cout << f_clone << std::endl; - auto f_clone_type = Downcast(f_clone->checked_type()); + auto f_clone_type = f_clone->func_type_annotation(); CHECK(FreeTypeVars(f_clone_type, mod).size() == 0) << "free type vars in specialized function"; CHECK(FreeVars(f_clone).size() == FreeVars(GetRef(f)).size()) @@ -90,17 +273,17 @@ class DefuncMutator : public ExprMutator { auto arg = EncodeFunctionArg(op->args[i], f_clone_type->arg_types[i].as()); args.push_back(arg); applyVars[f_clone->params[i]] = apply_map[f_clone->params[i]->checked_type()]; + } else { + CHECK(!HasFuncType(f_clone_type->arg_types[i])) << "nested function type in parameter not supported yet"; + args.push_back(op->args[i]); } - - CHECK(!HasFuncType(f_clone_type->arg_types[i])) << "nested function type in parameter not supported yet"; - args.push_back(op->args[i]); } + std::cout << "reached here" << std::endl; + auto new_func = ApplyVars(clone_gv, applyVars); - auto new_func = ApplyVars(f_clone, applyVars); - - return Call(ExprMutator::VisitExpr(new_func), args); + return Call(new_func, args); } - return ExprMutator::VisitExpr(GetRef(op)); + return ExprMutator::VisitExpr_(op); } // Expr VisitExpr_(const LetNode* op) { @@ -115,10 +298,6 @@ class DefuncMutator : public ExprMutator { // return GetRef(op); // } - Expr VisitExpr_(const GlobalVarNode* op) { CHECK(false) << "global var not supported yet"; - throw std::runtime_error("GlobalVar not supported"); - } - private: IRModule mod; // encode func type to ADT @@ -126,11 +305,13 @@ class DefuncMutator : public ExprMutator { std::unordered_map apply_map; // use monotonically increasing integer to represent new constructor_name unsigned int constructor_name; + unsigned int anon_name; - Expr ApplyVars(Expr body, const std::unordered_map& vars) { + Expr ApplyVars(GlobalVar gv, const std::unordered_map& vars) { struct ApplyVarMutator: public ExprMutator { std::unordered_map vars; - ApplyVarMutator(const std::unordered_map& vars) : vars(vars) {} + std::unordered_map var_map; + ApplyVarMutator(const std::unordered_map& vars, const std::unordered_map& var_map) : vars(vars), var_map(var_map) {} Expr VisitExpr_(const CallNode* op) { if (auto var_op = op->op.as()) { if (vars.count(GetRef(var_op)) != 0) { @@ -139,15 +320,34 @@ class DefuncMutator : public ExprMutator { for (auto arg: op->args) { args.push_back(arg); } - return Call(gv, args); + return ExprMutator::VisitExpr_(Call(gv, args).as()); } } return ExprMutator::VisitExpr_(op); } + + Expr VisitExpr_(const VarNode* op) { + auto var = GetRef(op); + if (var_map.count(var) != 0) { + return var_map[var]; + } + return ExprMutator::VisitExpr_(op); + } }; + auto e = Downcast(mod->Lookup(gv)); - return ApplyVarMutator(vars).Mutate(body); + std::unordered_map var_map; + for (auto v : e->params) { + if (v->type_annotation.as()) { + var_map[v] = Var(v->name_hint(), IncompleteType(TypeKind::kType)); + } + } + auto applied = Downcast(ApplyVarMutator(vars, var_map).Mutate(e)); + auto typed = this->VisitExpr(InferType(applied, mod, gv)); + mod->Add(gv, Downcast(typed), true); + + return gv; } void AddConstructor(GlobalTypeVar gtv, Constructor c) { @@ -161,9 +361,9 @@ class DefuncMutator : public ExprMutator { } } - void AddApplyCase(GlobalVar gv, FuncType ft, Constructor c) { + void AddApplyCase(GlobalVar gv, FuncType ft, Constructor c, const Expr& expr) { if (!mod->ContainGlobalVar(gv->name_hint)) { - auto x = Var("x", func_encoding[ft]); + auto x = Var("x", TypeCall(c->belong_to, {})); auto vars = Array({x}); auto args = Array(); for (auto t: ft->arg_types) { @@ -172,8 +372,7 @@ class DefuncMutator : public ExprMutator { args.push_back(y); } - - auto clauses = Array({Clause(PatternConstructor(c, {}), Call(x, args))}); + auto clauses = Array({Clause(PatternConstructor(c, {}), Call(expr, args))}); auto body = Match(x, clauses); auto f = Function(vars, body, ft->ret_type, {}); @@ -189,15 +388,16 @@ class DefuncMutator : public ExprMutator { for (size_t i = 1; i < f->params.size(); i++) { args.push_back(f->params[i]); } - clauses.push_back(Clause(PatternConstructor(c, {}), Call(x, args))); + clauses.push_back(Clause(PatternConstructor(c, {}), Call(expr, args))); mod->Add(gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true); } } Expr EncodeFunctionArg(const Expr& f, const FuncTypeNode* ft) { + auto adt_name = "T" + TypeToString(ft); if (func_encoding.count(GetRef(ft)) == 0) { - func_encoding[GetRef(ft)] = GlobalTypeVar("T" + TypeToString(ft), TypeKind::kAdtHandle); + func_encoding[GetRef(ft)] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); } auto gtv = func_encoding[GetRef(ft)]; @@ -209,27 +409,37 @@ class DefuncMutator : public ExprMutator { } auto gv = apply_map[GetRef(ft)]; - AddApplyCase(gv, GetRef(ft), c); + AddApplyCase(gv, GetRef(ft), c, f); return Call(c, {}); } std::string TypeToString(const TypeNode* t) { std::ostringstream s; - s << t; + s << GetRef(t); return s.str(); } - Expr Clone(const FunctionNode* f, const Array type_args) { - return DeDup(Specialize(f, type_args)); + GlobalVar Clone(std::string name_prefix, const FunctionNode* f, const Array type_args) { + auto spec = Specialize(f, type_args); + auto gv_name = name_prefix + TypeToString(spec->func_type_annotation().as()); + std::cout << gv_name << std::endl; + if (mod->ContainGlobalVar(gv_name)) { + return mod->GetGlobalVar(gv_name); + } + auto gv = GlobalVar(gv_name); + mod->Add(gv, Downcast(DeDup(spec))); + return gv; } - Expr Specialize(const FunctionNode* f, const Array type_args) { + Function Specialize(const FunctionNode* f, const Array type_args) { auto map = tvm::Map(); for (size_t i = 0; i < type_args.size(); i++) { map.Set(f->type_params[i], type_args[i]); } - return TypeSubst(GetRef(f), map); + // copy with typevars removed + auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map); + return Downcast(copy); } }; diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index 9c472542cc91..81ce64afe7a5 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -78,22 +78,6 @@ Type WithGradientType(const Type& t) { return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {}); } -//! \brief if the expression is a GlobalVar, transform to it's expression. -Expr DeGlobal(const Optional& mod, const Expr& e) { - const auto* x = e.as(); - - if (mod.defined() && x) { - BaseFunc base_func = mod.value()->Lookup(GetRef(x)); - if (auto* n = base_func.as()) { - return GetRef(n); - } else { - return e; - } - } else { - return e; - } -} - /*! \brief A fragment of the program being built by the automatic differentation * pass. */ diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index f3c99ccfa120..640ed788894f 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -222,6 +222,10 @@ std::pair CalcScope(const DependencyGraph& dg); */ Scope LCA(Scope lhs, Scope rhs); +/*! \brief if the expression is a GlobalVar, transform to it's expression. +*/ +Expr DeGlobal(const Optional& mod, const Expr& e); + /* Special care is needed to handle local recursion. * Fill additionally take a (possibly null) Var argument, * If it is not null, Fill is required to bind the transformed result to that var. diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index d73433c7f32c..9821074b240d 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -20,23 +20,48 @@ from tvm import relay from tvm.relay.transform import Defunctionalization, InferType -def test_local_simple(): +def test_simple(): code = """ #[version = "0.0.5"] +def @apply[A, B](%f: fn(A) -> B, %xs: A) -> B { + %f(%xs) +} def @main(%l: float32) -> float32 { - %0 = fn[A, B](%f: fn(A) -> B, %xs: A) -> B { - %f(%xs) - }; - %1 = fn[A](%x: A) -> A { + %0 = fn[A](%x: A) -> A { %x }; - %0(%1, %l) + @apply(%0, %l) } """ mod = tvm.parser.fromtext(code) mod = InferType()(mod) expr = Defunctionalization(mod['main'], mod) +def test_global_recursion(): + code = """ +#[version = "0.0.5"] +type List[A] { + Cons(A, List[A]), + Nil, +} +def @id[A](%x: A) -> A { + %x +} +def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] { + match (%xs) { + Cons(%x, %rest) => Cons(%f(%x), @map(%f, %rest)), + Nil => Nil, + } +} +def @main(%l: List[float32]) -> List[float32] { + @map(@id, %l) +} +""" + mod = tvm.parser.fromtext(code) + mod = InferType()(mod) + expr = Defunctionalization(mod['main'], mod) + if __name__ == "__main__": # pytest.main([__file__]) - test_local_simple() \ No newline at end of file + test_simple() + test_global_recursion() \ No newline at end of file From 86358c2a2a586ca9bf11ab5064c457e82b8d8135 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 2 Sep 2020 10:31:27 -0700 Subject: [PATCH 03/17] fix type arg infer --- src/relay/transforms/type_infer.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index e110737d6226..ee0c6268ebc2 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -659,6 +659,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } if (need_update_call) { + std::cout << new_e << std::endl; new_call->type_args = it->second.type_args; for (size_t i = 0; i < new_call->type_args.size(); i++) { new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i])); From 29fa2c86cd282d0d5590253e0dcc976db937f091 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 2 Sep 2020 10:32:07 -0700 Subject: [PATCH 04/17] WIP --- src/relay/transforms/defunctionalization.cc | 171 ++---------------- .../relay/test_pass_defunctionalization.py | 45 ++++- 2 files changed, 55 insertions(+), 161 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index a4de180cfd03..8bad4bf09499 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -64,117 +64,6 @@ bool IsHigherOrderFunc(const FuncType& t) { } Array InferTypeArgs(const CallNode* call, const IRModule& mod) { - // struct InferTypeArgsVisitor: public TypeFunctor { - // std::unordered_map typearg_map; - // std::unordered_set type_args; - // std::unordered_set poly_arg_type_params; - // InferTypeArgsVisitor(const std::unordered_set& type_args) : type_args(type_args) {} - // void VisitType_(const TypeVarNode* t, const Type& b) { - // if (auto tv = b.as()) { - // if (poly_arg_type_params.count(GetRef(tv)) > 0) { - // return; - // } - // } - - // auto tv = GetRef(t); - // if (type_args.count(tv) > 0) { - // if (typearg_map.count(tv) > 0) { - // std::cout << "L: " << typearg_map[tv] << std::endl; - // std::cout << "R: " << b << std::endl; - // CHECK(StructuralEqual()(typearg_map[tv], b)) << "failed to infer type args"; - // } - // typearg_map[tv] = b; - // } - // } - - // void VisitType_(const TensorTypeNode* t, const Type& b) {} - // void VisitType_(const TypeConstraintNode* t, const Type& b) {} - // void VisitType_(const FuncTypeNode* t, const Type& b) { - // if (auto tv = b.as()) { - // if (poly_arg_type_params.count(GetRef(tv)) > 0) { - // return; - // } - // } - - // auto fty = b.as(); - // CHECK(fty) << "expected func type when infering type args"; - // CHECK(t->arg_types.size() == fty->arg_types.size()) << "incorrect number of args when infering type args"; - - // if (fty->type_params.size() > 0) { - // for (auto t: fty->type_params) { - // poly_arg_type_params.insert(t); - // } - // } - - // for (size_t i = 0; i < t->arg_types.size(); i++) { - // this->VisitType(t->arg_types[i], fty->arg_types[i]); - // } - // this->VisitType(t->ret_type, fty->ret_type); - // } - // void VisitType_(const TupleTypeNode* t, const Type& b) { - // if (auto tv = b.as()) { - // if (poly_arg_type_params.count(GetRef(tv)) > 0) { - // return; - // } - // } - - // auto ty = b.as(); - // CHECK(ty) << "expected tuple type when infering type args"; - // CHECK(t->fields.size() == ty->fields.size()) << "incorrect tuple size when infering type args"; - // for (size_t i = 0; i < t->fields.size(); i++) { - // this->VisitType(t->fields[i], ty->fields[i]); - // } - // } - // void VisitType_(const TypeRelationNode* t, const Type& b) {} - // void VisitType_(const IncompleteTypeNode* t, const Type& b) { - // CHECK(false) << "encountered incompletetype when inferring type args"; - // } - // void VisitType_(const RelayRefTypeNode* t, const Type& b) { - // if (auto tv = b.as()) { - // if (poly_arg_type_params.count(GetRef(tv)) > 0) { - // return; - // } - // } - - // auto ty = b.as(); - // CHECK(ty) << "expected ref type when infering type args"; - // this->VisitType(t->value, ty->value); - // } - // void VisitType_(const GlobalTypeVarNode* t, const Type& b) { - // } - // void VisitType_(const TypeCallNode* t, const Type& b) { - // if (auto tv = b.as()) { - // if (poly_arg_type_params.count(GetRef(tv)) > 0) { - // return; - // } - // } - - // auto ty = b.as(); - // CHECK(ty) << "expected tuple type when infering type args"; - // CHECK(t->args.size() == ty->args.size()) << "incorrect tuple size when infering type args"; - // for (size_t i = 0; i < t->args.size(); i++) { - // this->VisitType(t->args[i], ty->args[i]); - // } - // } - // void VisitType_(const TypeDataNode* t, const Type& b) {} - // void VisitType_(const PrimTypeNode* t, const Type& b) {} - // void VisitType_(const PointerTypeNode* t, const Type& b) {} - // }; - - // std::unordered_set type_args; - // for (auto tv: f->type_params) { type_args.insert(tv); } - // auto itav = InferTypeArgsVisitor(type_args); - // for (size_t i = 0; i < f->params.size(); i++) { - // itav.VisitType(f->params[i]->checked_type(), args[i]->checked_type()); - // } - - // Array typeargs; - // for (auto tv: f->type_params) { typeargs.push_back(itav.typearg_map[tv]); - // std::cout<< "resolved type" << itav.typearg_map[tv] << std::endl; - // } - - // return typeargs; - std::cout << "START" << std::endl; ErrorReporter err; TypeSolver solver(mod->GetGlobalVar("main"), mod, &err); const FuncTypeNode* fn_ty = call->op->checked_type().as(); @@ -186,54 +75,20 @@ Array InferTypeArgs(const CallNode* call, const IRModule& mod) { auto inst_fnty = FuncType(fn_ty->arg_types, fn_ty->ret_type, {}, {}); auto f_incomplete = Downcast(Bind(inst_fnty, subst_map)); - std::cout << f_incomplete << std::endl; - // Array arg_types; - - // for (auto t: call->args) { - // auto ty = t->checked_type(); - // auto bound = BoundTypeVars(ty, mod); - // for (auto tv: bound) { - // subst_map.Set(tv, IncompleteType(Kind::kType)); - // } - // if (auto fn_ty = ty.as()) { - // arg_types.push_back(TypeSubst(FuncType(fn_ty->arg_types, fn_ty->ret_type, {}, fn_ty->type_constraints), subst_map)); - // } else { - // arg_types.push_back(TypeSubst(ty, subst_map)); - // } - // } - std::cout << "REACHED" << std::endl; - // CHECK(arg_types.size() == f_incomplete->arg_types.size()) << "num of arguments does not match expected"; + + CHECK(call->args.size() == f_incomplete->arg_types.size()) << "num of arguments does not match expected"; size_t num_args = f_incomplete->arg_types.size(); - // for (size_t i = 0; i < num_args; i++) { - // std::cout << "size: " << num_args << "; i: " << i << std::endl; - // // auto t1 = f_incomplete->arg_types[i]; - // // auto t2 = call->args[i]->checked_type(); - // // std::cout << "l: "<< t1 << std::endl; - // // std::cout << "r: "<(call)); - // // std::cout << "Univifed: " << t << std::endl; - // } for (size_t i = 0; i < num_args; i++) { - std::cout << "i: " << i << "; num_args: " << num_args << std::endl; + auto t1 = f_incomplete->arg_types[i]; + auto t2 = call->args[i]->checked_type(); + auto t = solver.Unify(t1, t2, GetRef(call)); } - - // for (size_t i = 0; i < f_incomplete->arg_types.size(); i++) { - // std::cout << "size: " << f_incomplete->arg_types.size() << "; i: " << i << std::endl; - // auto t1 = f_incomplete->arg_types[i]; - // auto t2 = arg_types[i]; - // std::cout << "l: " << t1 << " r: " << t2 << std::endl; - - // try { - // auto t = solver.Unify(t1, t2, GetRef(call)); - // std::cout << "Univifed: " << t << std::endl; - // } catch (const dmlc::Error& e) { - // CHECK(false) << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what(); - // } - // } - - // for (auto& tv: fn_ty->type_params) { - // std::cout << "Resolved: " << solver.Resolve(subst_map[tv]); - // } + Array ret; + for (auto& tv: fn_ty->type_params) { + std::cout << "Resolved Type: " << solver.Resolve(subst_map[tv]) << std::endl; + ret.push_back(solver.Resolve(subst_map[tv])); + } + return ret; } class DefuncMutator : public ExprMutator { @@ -278,7 +133,6 @@ class DefuncMutator : public ExprMutator { args.push_back(op->args[i]); } } - std::cout << "reached here" << std::endl; auto new_func = ApplyVars(clone_gv, applyVars); return Call(new_func, args); @@ -322,7 +176,7 @@ class DefuncMutator : public ExprMutator { } return ExprMutator::VisitExpr_(Call(gv, args).as()); } - } + } else if (IsHigherOrderFunc(Downcast(op->op->checked_type()))) return ExprMutator::VisitExpr_(op); } @@ -345,6 +199,7 @@ class DefuncMutator : public ExprMutator { } auto applied = Downcast(ApplyVarMutator(vars, var_map).Mutate(e)); auto typed = this->VisitExpr(InferType(applied, mod, gv)); + std::cout << "TYPED: " << typed << std::endl; mod->Add(gv, Downcast(typed), true); return gv; diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index 9821074b240d..dbe5f45015eb 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -18,7 +18,7 @@ import tvm from tvm import relay -from tvm.relay.transform import Defunctionalization, InferType +from tvm.relay.transform import Defunctionalization, InferType, LambdaLift def test_simple(): code = """ @@ -34,6 +34,7 @@ def @main(%l: float32) -> float32 { } """ mod = tvm.parser.fromtext(code) + mod = LambdaLift()(mod) mod = InferType()(mod) expr = Defunctionalization(mod['main'], mod) @@ -58,10 +59,48 @@ def @main(%l: List[float32]) -> List[float32] { } """ mod = tvm.parser.fromtext(code) + mod = LambdaLift()(mod) mod = InferType()(mod) expr = Defunctionalization(mod['main'], mod) +def test_sum(): + code = """ +#[version = "0.0.5"] +type List[A] { + Cons(A, List[A]), + Nil, +} +def @main(%f: fn(int32) -> int32, %xs: List[int32]) -> int32 { + match (%xs) { + Cons(%x, %rest) => %0 = fn(%n) { + %x + %f(%n) + }; + @main(%0, %rest), + Nil => %f(0), + } +} +""" + mod = tvm.parser.fromtext(code) + mod = LambdaLift()(mod) + mod = InferType()(mod) + print(mod) + +def test(): + code = """ +#[version = "0.0.5"] +def @id[A](%x: A) -> A { + %x +} +def @main(%f: float32) -> float32 { + @id(%f) +} +""" + mod = tvm.parser.fromtext(code) + mod = InferType()(mod) + print(mod['main'].body.type_args) + if __name__ == "__main__": # pytest.main([__file__]) - test_simple() - test_global_recursion() \ No newline at end of file + # test_simple() + # test_global_recursion() + test() \ No newline at end of file From 0d46c277a1e7dfb236cba8638e6164846d16d255 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 2 Sep 2020 17:17:34 -0700 Subject: [PATCH 05/17] wip --- src/relay/transforms/defunctionalization.cc | 96 ++++++++++--------- .../relay/test_pass_defunctionalization.py | 2 +- 2 files changed, 52 insertions(+), 46 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 8bad4bf09499..162034a2d6a6 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -63,33 +63,33 @@ bool IsHigherOrderFunc(const FuncType& t) { return higher_order |= HasFuncType(t->ret_type); } -Array InferTypeArgs(const CallNode* call, const IRModule& mod) { - ErrorReporter err; - TypeSolver solver(mod->GetGlobalVar("main"), mod, &err); - const FuncTypeNode* fn_ty = call->op->checked_type().as(); - - tvm::Map subst_map; - for (auto& tv: fn_ty->type_params) { - subst_map.Set(tv, IncompleteType(Kind::kType)); - } - - auto inst_fnty = FuncType(fn_ty->arg_types, fn_ty->ret_type, {}, {}); - auto f_incomplete = Downcast(Bind(inst_fnty, subst_map)); - - CHECK(call->args.size() == f_incomplete->arg_types.size()) << "num of arguments does not match expected"; - size_t num_args = f_incomplete->arg_types.size(); - for (size_t i = 0; i < num_args; i++) { - auto t1 = f_incomplete->arg_types[i]; - auto t2 = call->args[i]->checked_type(); - auto t = solver.Unify(t1, t2, GetRef(call)); - } - Array ret; - for (auto& tv: fn_ty->type_params) { - std::cout << "Resolved Type: " << solver.Resolve(subst_map[tv]) << std::endl; - ret.push_back(solver.Resolve(subst_map[tv])); - } - return ret; -} +// Array InferTypeArgs(const CallNode* call, const IRModule& mod) { +// ErrorReporter err; +// TypeSolver solver(mod->GetGlobalVar("main"), mod, &err); +// const FuncTypeNode* fn_ty = call->op->checked_type().as(); + +// tvm::Map subst_map; +// for (auto& tv: fn_ty->type_params) { +// subst_map.Set(tv, IncompleteType(Kind::kType)); +// } + +// auto inst_fnty = FuncType(fn_ty->arg_types, fn_ty->ret_type, {}, {}); +// auto f_incomplete = Downcast(Bind(inst_fnty, subst_map)); + +// CHECK(call->args.size() == f_incomplete->arg_types.size()) << "num of arguments does not match expected"; +// size_t num_args = f_incomplete->arg_types.size(); +// for (size_t i = 0; i < num_args; i++) { +// auto t1 = f_incomplete->arg_types[i]; +// auto t2 = call->args[i]->checked_type(); +// auto t = solver.Unify(t1, t2, GetRef(call)); +// } +// Array ret; +// for (auto& tv: fn_ty->type_params) { +// std::cout << "Resolved Type: " << solver.Resolve(subst_map[tv]) << std::endl; +// ret.push_back(solver.Resolve(subst_map[tv])); +// } +// return ret; +// } class DefuncMutator : public ExprMutator { public: @@ -97,24 +97,26 @@ class DefuncMutator : public ExprMutator { Expr VisitExpr_(const CallNode* op) { auto op_func = op->op; - auto f = DeGlobal(mod, op_func).as(); - CHECK(f) << "only calls to functions or globalvars are supported so far"; - // CHECK(op->type_args.size() == f->type_params.size()) << "all type args must be explicit"; - - // clone function and specialize if there are higher order functions - if (IsHigherOrderFunc(Downcast(f->func_type_annotation()))) { - std::cout << "Call Function: " << GetRef(f) << std::endl; + CHECK(op->type_args.size() == f->type_params.size()) << "all type args must be explicit"; + // clone function and specialize if it is a higher order functions + auto op_type = InstFuncType(op_func->checked_type().as(), op->type_args); + if (IsHigherOrderFunc(op_type)) { std::string name; if (auto gv = op->op.as()) { name = gv->name_hint; } else { name = "anon" + std::to_string(anon_name++); } - auto clone_gv = Clone(name, f, InferTypeArgs(op, mod)); - auto f_clone = Downcast(DeGlobal(mod, clone_gv)); - std::cout << f_clone << std::endl; - auto f_clone_type = f_clone->func_type_annotation(); + name += TypeToString(op_type); + + auto f = DeGlobal(mod, op_func).as(); + CHECK(f) << "only calls to functions or globalvars are supported so far"; + std::cout << "Call Function: " << GetRef(f) << std::endl; + + auto gv = Clone(name, f, op->type_args); + for (op_type) + CHECK(FreeTypeVars(f_clone_type, mod).size() == 0) << "free type vars in specialized function"; CHECK(FreeVars(f_clone).size() == FreeVars(GetRef(f)).size()) @@ -275,18 +277,22 @@ class DefuncMutator : public ExprMutator { return s.str(); } - GlobalVar Clone(std::string name_prefix, const FunctionNode* f, const Array type_args) { + GlobalVar Clone(std::string name, const FunctionNode* f, const Array type_args) { auto spec = Specialize(f, type_args); - auto gv_name = name_prefix + TypeToString(spec->func_type_annotation().as()); - std::cout << gv_name << std::endl; - if (mod->ContainGlobalVar(gv_name)) { - return mod->GetGlobalVar(gv_name); - } - auto gv = GlobalVar(gv_name); + auto gv = GlobalVar(name); mod->Add(gv, Downcast(DeDup(spec))); return gv; } + FuncType InstFuncType(const FuncTypeNode* fty, const Array type_args) { + auto map = tvm::Map(); + for (size_t i = 0; i < type_args.size(); i++) { + map.Set(f->type_params[i], type_args[i]); + } + // copy with typevars removed + return TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map); + } + Function Specialize(const FunctionNode* f, const Array type_args) { auto map = tvm::Map(); for (size_t i = 0; i < type_args.size(); i++) { diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index dbe5f45015eb..b417bf4c41f1 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -92,7 +92,7 @@ def @id[A](%x: A) -> A { %x } def @main(%f: float32) -> float32 { - @id(%f) + @id(@id)(%f) } """ mod = tvm.parser.fromtext(code) From 79932d9912654bd5758973b243dd9e445aede5b1 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Thu, 3 Sep 2020 16:33:26 -0700 Subject: [PATCH 06/17] wip --- src/relay/transforms/defunctionalization.cc | 217 ++++++++---------- .../relay/test_pass_defunctionalization.py | 6 +- 2 files changed, 105 insertions(+), 118 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 162034a2d6a6..ee389e0f538d 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -95,118 +95,84 @@ class DefuncMutator : public ExprMutator { public: DefuncMutator(const IRModule& mod) : mod(mod), constructor_name(0), anon_name(0) {} - Expr VisitExpr_(const CallNode* op) { - auto op_func = op->op; - CHECK(op->type_args.size() == f->type_params.size()) << "all type args must be explicit"; - - // clone function and specialize if it is a higher order functions - auto op_type = InstFuncType(op_func->checked_type().as(), op->type_args); - if (IsHigherOrderFunc(op_type)) { - std::string name; - if (auto gv = op->op.as()) { - name = gv->name_hint; + Expr VisitExpr_(const CallNode* call) { + if (auto op = call->op.as()) { + CHECK(call->type_args.size() == op->checked_type().as()->type_params.size()) << "all type args must be explicit"; + + auto op_type = InstFuncType(op->checked_type().as(), call->type_args); + + CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported"; + if (!IsHigherOrderFunc(op_type)) { + return ExprMutator::VisitExpr_(call); + } + auto name = op->name_hint + TypeToString(op_type); + auto gv = GlobalVar(name); + if (mod->ContainGlobalVar(name)) { + gv = mod->GetGlobalVar(name); } else { - name = "anon" + std::to_string(anon_name++); + // clone and specialize with specific type + auto clone = DeDup(DeGlobal(mod, GetRef(op))).as(); + auto specialized_function = Specialize(clone, call->type_args); + auto f = Downcast(this->VisitExpr(FirstifyVars(specialized_function))); + mod->Add(gv, f); } - name += TypeToString(op_type); - auto f = DeGlobal(mod, op_func).as(); - CHECK(f) << "only calls to functions or globalvars are supported so far"; - std::cout << "Call Function: " << GetRef(f) << std::endl; - - auto gv = Clone(name, f, op->type_args); - for (op_type) + Array args; + for (size_t i = 0; i < call->args.size(); i++) { + auto arg = call->args[i]; + auto type = op_type->arg_types[i]; + // we assume arg is either an identifier or a function + if (!HasFuncType(type)) { + args.push_back(arg); + continue; + } - CHECK(FreeTypeVars(f_clone_type, mod).size() == 0) - << "free type vars in specialized function"; - CHECK(FreeVars(f_clone).size() == FreeVars(GetRef(f)).size()) - << "local closures not supported yet"; - CHECK(!HasFuncType(f_clone_type->ret_type)) << "returning function not supported yet"; + CHECK(type.as()) << "assume no nested functions"; - Array args; - std::unordered_map applyVars; - for (size_t i = 0; i < f_clone_type->arg_types.size(); i++) { - if (f_clone_type->arg_types[i].as()) { - auto arg = EncodeFunctionArg(op->args[i], f_clone_type->arg_types[i].as()); + if (arg.as()) { args.push_back(arg); - applyVars[f_clone->params[i]] = apply_map[f_clone->params[i]->checked_type()]; - } else { - CHECK(!HasFuncType(f_clone_type->arg_types[i])) << "nested function type in parameter not supported yet"; - args.push_back(op->args[i]); } + if (arg.as()) { + args.push_back(EncodeGlobalVar(Downcast(arg), Downcast(type))); + } + if (arg.as()) { + + } + CHECK(false) << "assume all first-order-parameters are identifiers or functions"; } - auto new_func = ApplyVars(clone_gv, applyVars); - return Call(new_func, args); + } else if (auto op = call->op.as()) { + std::unordered_map var_binding_map; + for (size_t i = 0; i < op->params.size(); i++) { + var_binding_map[op->params[i]] = call->args[i]; + } + auto e = Bind(op->body, var_binding_map); + return this->VisitExpr(e); + } else if (auto op = call->op.as()) { + auto op_type = InstFuncType(var_save_type[GetRef(op)].as(), call->type_args); + + Array args = {GetRef(op)}; + for (auto arg: call->args) { + args.push_back(this->VisitExpr(arg)); + } + + auto e = Call(apply_map[op_type], args); + return e; } - return ExprMutator::VisitExpr_(op); + return ExprMutator::VisitExpr_(call); } - // Expr VisitExpr_(const LetNode* op) { - // var_map[op->var] = this->VisitExpr(op->value); - // return this->VisitExpr(op->body); - // } - - // Expr VisitExpr_(const VarNode* op) { - // if (var_map.count(GetRef(op)) != 0) { - // return var_map[GetRef(op)]; - // } - // return GetRef(op); - // } - private: IRModule mod; // encode func type to ADT std::unordered_map func_encoding; std::unordered_map apply_map; + std::unordered_map var_save_type; + std::unordered_map, ObjectHash, ObjectEqual> gv_datatype_map; // use monotonically increasing integer to represent new constructor_name unsigned int constructor_name; unsigned int anon_name; - Expr ApplyVars(GlobalVar gv, const std::unordered_map& vars) { - struct ApplyVarMutator: public ExprMutator { - std::unordered_map vars; - std::unordered_map var_map; - ApplyVarMutator(const std::unordered_map& vars, const std::unordered_map& var_map) : vars(vars), var_map(var_map) {} - Expr VisitExpr_(const CallNode* op) { - if (auto var_op = op->op.as()) { - if (vars.count(GetRef(var_op)) != 0) { - auto gv = vars[GetRef(var_op)]; - Array args = {GetRef(var_op)}; - for (auto arg: op->args) { - args.push_back(arg); - } - return ExprMutator::VisitExpr_(Call(gv, args).as()); - } - } else if (IsHigherOrderFunc(Downcast(op->op->checked_type()))) - - return ExprMutator::VisitExpr_(op); - } - - Expr VisitExpr_(const VarNode* op) { - auto var = GetRef(op); - if (var_map.count(var) != 0) { - return var_map[var]; - } - return ExprMutator::VisitExpr_(op); - } - }; - auto e = Downcast(mod->Lookup(gv)); - - std::unordered_map var_map; - for (auto v : e->params) { - if (v->type_annotation.as()) { - var_map[v] = Var(v->name_hint(), IncompleteType(TypeKind::kType)); - } - } - auto applied = Downcast(ApplyVarMutator(vars, var_map).Mutate(e)); - auto typed = this->VisitExpr(InferType(applied, mod, gv)); - std::cout << "TYPED: " << typed << std::endl; - mod->Add(gv, Downcast(typed), true); - - return gv; - } - void AddConstructor(GlobalTypeVar gtv, Constructor c) { if (!mod->ContainGlobalTypeVar(gtv->name_hint)) { mod->AddTypeDef(gtv, TypeData(gtv, {}, {c})); @@ -251,46 +217,44 @@ class DefuncMutator : public ExprMutator { } } - Expr EncodeFunctionArg(const Expr& f, const FuncTypeNode* ft) { - auto adt_name = "T" + TypeToString(ft); - if (func_encoding.count(GetRef(ft)) == 0) { - func_encoding[GetRef(ft)] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); - } + Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) { + auto map = gv_datatype_map[gv]; + if (map.count(ft) == 0) { + auto adt_name = "T" + TypeToString(ft); - auto gtv = func_encoding[GetRef(ft)]; - auto c = Constructor(std::to_string(constructor_name++), {}, gtv); - AddConstructor(gtv, c); + if (func_encoding.count(ft) == 0) { + func_encoding[ft] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); + } - if (apply_map.count(GetRef(ft)) == 0) { - apply_map[GetRef(ft)] = GlobalVar("apply" + TypeToString(ft)); - } + auto gtv = func_encoding[ft]; + auto c = Constructor(std::to_string(constructor_name++), {}, gtv); + AddConstructor(gtv, c); - auto gv = apply_map[GetRef(ft)]; - AddApplyCase(gv, GetRef(ft), c, f); + if (apply_map.count(ft) == 0) { + apply_map[ft] = GlobalVar("apply" + TypeToString(ft)); + } + auto gv = apply_map[ft]; + AddApplyCase(gv, ft, c, gv); + } + + auto c = map[ft]; return Call(c, {}); } - std::string TypeToString(const TypeNode* t) { + std::string TypeToString(const Type& t) { std::ostringstream s; - s << GetRef(t); + s << t; return s.str(); } - GlobalVar Clone(std::string name, const FunctionNode* f, const Array type_args) { - auto spec = Specialize(f, type_args); - auto gv = GlobalVar(name); - mod->Add(gv, Downcast(DeDup(spec))); - return gv; - } - FuncType InstFuncType(const FuncTypeNode* fty, const Array type_args) { auto map = tvm::Map(); for (size_t i = 0; i < type_args.size(); i++) { - map.Set(f->type_params[i], type_args[i]); + map.Set(fty->type_params[i], type_args[i]); } // copy with typevars removed - return TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map); + return Downcast(TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map)); } Function Specialize(const FunctionNode* f, const Array type_args) { @@ -302,6 +266,28 @@ class DefuncMutator : public ExprMutator { auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map); return Downcast(copy); } + + Function FirstifyVars(const Function& f) { + CHECK(f->type_params.size() == 0) << "firstify function has type params"; + + std::unordered_map var_bind_map; + for (auto var: f->params) { + if (auto var_type = var->checked_type().as()) { + // first order parameter + auto fop_type = GetRef(var_type); + if (func_encoding.count(fop_type) == 0) { + auto name = "T" + TypeToString(fop_type); + func_encoding[fop_type] = GlobalTypeVar(name, TypeKind::kAdtHandle); + } + auto adt = func_encoding[fop_type]; + var_bind_map[var] = Var(var->name_hint(), TypeCall(adt, {})); + } else { + CHECK(!HasFuncType(var->checked_type())) << "nested function type in parameter not supported yet"; + } + } + + return Downcast(Bind(f, var_bind_map)); + } }; Expr Defunctionalization(const Expr& e, const IRModule& mod) { @@ -311,6 +297,7 @@ Expr Defunctionalization(const Expr& e, const IRModule& mod) { for (const auto& p : f->params) { CHECK(!HasFuncType(p)) << "input parameters cannot have func type"; } + CHECK(!HasFuncType(f->ret_type)) << "return type cannot contain function"; return DefuncMutator(mod).VisitExpr(e); } diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index b417bf4c41f1..13725f52dc27 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -59,9 +59,9 @@ def @main(%l: List[float32]) -> List[float32] { } """ mod = tvm.parser.fromtext(code) - mod = LambdaLift()(mod) + # mod = LambdaLift()(mod) mod = InferType()(mod) - expr = Defunctionalization(mod['main'], mod) + # expr = Defunctionalization(mod['main'], mod) def test_sum(): code = """ @@ -103,4 +103,4 @@ def @main(%f: float32) -> float32 { # pytest.main([__file__]) # test_simple() # test_global_recursion() - test() \ No newline at end of file + test_global_recursion() \ No newline at end of file From 2e27fbbf2f93de468e64ff67ea069742c40f760d Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Thu, 3 Sep 2020 16:34:51 -0700 Subject: [PATCH 07/17] revert type_infer --- src/relay/transforms/type_infer.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index ee0c6268ebc2..5579d8744e5a 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -375,6 +375,10 @@ class TypeInferencer : private ExprFunctor, subst_map.Set(fn_ty->type_params[i], ty_args[i]); } + for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) { + subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType)); + } + Type ret_type = fn_ty->ret_type; // If the function type is incomplete, place a new IncompleteType @@ -659,7 +663,6 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } if (need_update_call) { - std::cout << new_e << std::endl; new_call->type_args = it->second.type_args; for (size_t i = 0; i < new_call->type_args.size(); i++) { new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i])); From 0fec2f79b7a106764ef31f80704863f27e1d5727 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 00:39:49 -0700 Subject: [PATCH 08/17] working --- src/relay/transforms/defunctionalization.cc | 331 ++++++++++++------ .../relay/test_pass_defunctionalization.py | 113 ++++-- 2 files changed, 308 insertions(+), 136 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index ee389e0f538d..c6d12522ff8a 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -21,7 +21,42 @@ * * \file defunctionalization.cc * - * \brief + * \brief Defunctionalization for Relay IR + * + * This pass transforms a higher-order program into a first-order program with defunctionalization. + * This means that all higher order functions (i.e functions that take function arguments or return + * functions) should be transformed into a semantically equivalent first order one. + * + * This pass implements a basic typed defunctionalization method. + * All higher order functions are cloned and specialized (so that there are no type params). + * Function type arguments are encoded as datatypes and a helper `apply` function is used + * to "call" them. + * + * For example, take the following higher order program: + * fun map F y = case y of + * Nil => Nil + * | Cons(x, XS) => Cons(F z, map F XS) + * fun addone 1 = map (\x -> \x + 1) 1 + * + * where `addone` is our program. + * When we call the `map` function, we see that it is a higher-order function, + * but we can clone `map ` function and specialize it with the type_params of the call. + * In addition, our function argument `(\x -> \x + 1)` will be encoded as a datatype constructor, + * which we will call `incr`, and all calls to `F` in our specialized map function will use the + * helper `apply` function. + * + * After defunctionalization, we get: + * fun apply encoding arg = case encoding of + * “incr” => incr arg + * fun map’ F y = case y of + * Nil => Nil + * | Cons(x, xs) => Cons(apply F x, map’ F xs) + * fun addone 1 = map’ “incr” 1 + * + * Currently, defunctionalization makes the following assumptions: + * - functions cannot return function values + * - function arguments are in two forms: identifier or a lambda abstraction + * - no functions stored in datatype */ #include @@ -42,137 +77,166 @@ struct FuncTypeVisitor : TypeVisitor { void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; } }; - +// determine if expr contains a FuncType bool HasFuncType(const Expr& e) { auto visitor = FuncTypeVisitor(); visitor.VisitType(e->checked_type()); return visitor.has_func_type; } - +// determine if type contains a FuncType bool HasFuncType(const Type& t) { auto visitor = FuncTypeVisitor(); visitor.VisitType(t); return visitor.has_func_type; } - +// determine if FuncType is a higher order type bool IsHigherOrderFunc(const FuncType& t) { bool higher_order = false; - for (auto arg: t->arg_types) { + for (auto arg : t->arg_types) { higher_order |= HasFuncType(arg); } return higher_order |= HasFuncType(t->ret_type); } -// Array InferTypeArgs(const CallNode* call, const IRModule& mod) { -// ErrorReporter err; -// TypeSolver solver(mod->GetGlobalVar("main"), mod, &err); -// const FuncTypeNode* fn_ty = call->op->checked_type().as(); - -// tvm::Map subst_map; -// for (auto& tv: fn_ty->type_params) { -// subst_map.Set(tv, IncompleteType(Kind::kType)); -// } - -// auto inst_fnty = FuncType(fn_ty->arg_types, fn_ty->ret_type, {}, {}); -// auto f_incomplete = Downcast(Bind(inst_fnty, subst_map)); - -// CHECK(call->args.size() == f_incomplete->arg_types.size()) << "num of arguments does not match expected"; -// size_t num_args = f_incomplete->arg_types.size(); -// for (size_t i = 0; i < num_args; i++) { -// auto t1 = f_incomplete->arg_types[i]; -// auto t2 = call->args[i]->checked_type(); -// auto t = solver.Unify(t1, t2, GetRef(call)); -// } -// Array ret; -// for (auto& tv: fn_ty->type_params) { -// std::cout << "Resolved Type: " << solver.Resolve(subst_map[tv]) << std::endl; -// ret.push_back(solver.Resolve(subst_map[tv])); -// } -// return ret; -// } - +/*! + * \brief apply Defunctionalization transform + */ class DefuncMutator : public ExprMutator { public: - DefuncMutator(const IRModule& mod) : mod(mod), constructor_name(0), anon_name(0) {} + DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {} Expr VisitExpr_(const CallNode* call) { if (auto op = call->op.as()) { - CHECK(call->type_args.size() == op->checked_type().as()->type_params.size()) << "all type args must be explicit"; + CHECK(call->type_args.size() == op->checked_type().as()->type_params.size()) + << "all type args must be explicit"; auto op_type = InstFuncType(op->checked_type().as(), call->type_args); - + CHECK(FreeTypeVars(op_type, mod).size() == 0) << "free type vars in instantiated"; CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported"; + if (!IsHigherOrderFunc(op_type)) { + // not higher order function return ExprMutator::VisitExpr_(call); } - auto name = op->name_hint + TypeToString(op_type); - auto gv = GlobalVar(name); - if (mod->ContainGlobalVar(name)) { - gv = mod->GetGlobalVar(name); - } else { - // clone and specialize with specific type - auto clone = DeDup(DeGlobal(mod, GetRef(op))).as(); - auto specialized_function = Specialize(clone, call->type_args); - auto f = Downcast(this->VisitExpr(FirstifyVars(specialized_function))); - mod->Add(gv, f); - } + // first we encode function arguments Array args; for (size_t i = 0; i < call->args.size(); i++) { auto arg = call->args[i]; auto type = op_type->arg_types[i]; - // we assume arg is either an identifier or a function if (!HasFuncType(type)) { args.push_back(arg); continue; } + // we assume arg is either an identifier (var or globalvar) or a function CHECK(type.as()) << "assume no nested functions"; + CHECK(arg.as() || arg.as() || arg.as()) + << "assume all first-order-parameters are identifiers or functions"; if (arg.as()) { + // variable with functype will be encoded as datatype in surrounding function args.push_back(arg); } if (arg.as()) { args.push_back(EncodeGlobalVar(Downcast(arg), Downcast(type))); } - if (arg.as()) { - + if (auto fn = arg.as()) { + // we handle free vars in anonymous functions by adding arguments to + // the constructor function + auto free_vars = FreeVars(arg); + auto ft = Downcast(type); + + auto arg_types = Array(); + auto pattern_vars = Array(); + auto call_args = Array(); + Map free_var_bind_map; + for (auto free_var : free_vars) { + // free vars are already encoded + if (free_var->type_annotation.defined()) { + arg_types.push_back(free_var->type_annotation); + } else { + arg_types.push_back(free_var->checked_type()); + } + auto new_var = Var(free_var->name_hint(), free_var->type_annotation); + free_var_bind_map.Set(free_var, new_var); + pattern_vars.push_back(PatternVar(new_var)); + call_args.push_back(free_var); + } + auto gtv = GetFuncEncode(ft); + auto c = Constructor(std::to_string(constructor_counter++), arg_types, gtv); + AddConstructor(gtv, c); + + auto apply_gv = GetApplyFunction(ft); + auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); + AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params), + pattern_vars); + + args.push_back(Call(c, call_args)); } - CHECK(false) << "assume all first-order-parameters are identifiers or functions"; } - + auto name = op->name_hint + TypeToString(op_type); + auto gv = GlobalVar(name); + if (specialized_gv_map.count(name)) { + gv = specialized_gv_map[name]; + } else { + specialized_gv_map[name] = gv; + // clone and specialize with specific type + auto clone = Downcast(DeDup(DeGlobal(mod, GetRef(op)))); + auto specialized_function = Specialize(clone, call->type_args); + // change var types and change all applications to use `apply` method + auto f = Downcast(FirstifyVars(specialized_function)); + mod->Add(gv, f); + } + return Call(gv, args); } else if (auto op = call->op.as()) { + // reduction by applying vars std::unordered_map var_binding_map; for (size_t i = 0; i < op->params.size(); i++) { var_binding_map[op->params[i]] = call->args[i]; - } + } auto e = Bind(op->body, var_binding_map); return this->VisitExpr(e); } else if (auto op = call->op.as()) { - auto op_type = InstFuncType(var_save_type[GetRef(op)].as(), call->type_args); - + // var node will be encoded as datatype + // so we need to use the `apply` helper method + auto var_original_type = GetUnencodedType(op->type_annotation).as(); + CHECK(var_original_type) << "var original type not saved in var_save_type map"; + auto op_type = InstFuncType(var_original_type, call->type_args); + Array args = {GetRef(op)}; - for (auto arg: call->args) { + for (auto arg : call->args) { args.push_back(this->VisitExpr(arg)); } - auto e = Call(apply_map[op_type], args); + auto e = Call(GetApplyFunction(op_type), args); return e; } return ExprMutator::VisitExpr_(call); } private: + // module IRModule mod; - // encode func type to ADT - std::unordered_map func_encoding; - std::unordered_map apply_map; - std::unordered_map var_save_type; - std::unordered_map, ObjectHash, ObjectEqual> gv_datatype_map; + // gv + str(type) to specialized clone gv + std::unordered_map specialized_gv_map; + // str(func_type) to ADT + std::unordered_map func_encoding; + // str(func_tyoe) to apply gv + std::unordered_map apply_map; + // encoded ADT handle to FuncType + std::unordered_map original_func_type_map; + // gv to (str(func_type) to constructor encoding) + std::unordered_map, ObjectHash, + ObjectEqual> + gv_datatype_map; // use monotonically increasing integer to represent new constructor_name - unsigned int constructor_name; - unsigned int anon_name; + unsigned long constructor_counter; + /*! + * \brief add a constructor to the GlobalTypeVar, creating a new TypeDef if GlobalTypeVar does not + * exist + */ void AddConstructor(GlobalTypeVar gtv, Constructor c) { if (!mod->ContainGlobalTypeVar(gtv->name_hint)) { mod->AddTypeDef(gtv, TypeData(gtv, {}, {c})); @@ -183,25 +247,37 @@ class DefuncMutator : public ExprMutator { mod->UpdateTypeDef(gtv, TypeData(typedata->header, typedata->type_vars, constructors)); } } - - void AddApplyCase(GlobalVar gv, FuncType ft, Constructor c, const Expr& expr) { - if (!mod->ContainGlobalVar(gv->name_hint)) { + /*! + * \brief add a case to the apply function, creating the function if it does not exist + * + * \param apply_gv GlobalVar of the apply function + * \param ft is the type functions the apply function handles + * \param c constructor to add a case for + * \param expr calls this expr with the args to the apply_gv + * \param patterns PatterVars to match with the constructor, used for handling free vars in + * functions + */ + void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr, + const Array patterns) { + CHECK(c->inputs.size() == patterns.size()) + << "constructor function and pattern vars have different sizes"; + if (!mod->ContainGlobalVar(apply_gv->name_hint)) { auto x = Var("x", TypeCall(c->belong_to, {})); auto vars = Array({x}); auto args = Array(); - for (auto t: ft->arg_types) { + for (auto t : ft->arg_types) { auto y = Var("y", t); vars.push_back(y); args.push_back(y); } - auto clauses = Array({Clause(PatternConstructor(c, {}), Call(expr, args))}); + auto clauses = Array({Clause(PatternConstructor(c, patterns), Call(expr, args))}); auto body = Match(x, clauses); auto f = Function(vars, body, ft->ret_type, {}); - mod->Add(gv, f); + mod->Add(apply_gv, f); } else { - auto f = Downcast(mod->Lookup(gv)); + auto f = Downcast(mod->Lookup(apply_gv)); auto body = f->body.as(); CHECK(body) << "internal invariant broken; apply function body should be a match node"; @@ -211,44 +287,78 @@ class DefuncMutator : public ExprMutator { for (size_t i = 1; i < f->params.size(); i++) { args.push_back(f->params[i]); } - clauses.push_back(Clause(PatternConstructor(c, {}), Call(expr, args))); + clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args))); - mod->Add(gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true); + mod->Add(apply_gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true); } } + /*! + * \brief encode a global var with a specialized type with a datatype + */ Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) { auto map = gv_datatype_map[gv]; - if (map.count(ft) == 0) { - auto adt_name = "T" + TypeToString(ft); - - if (func_encoding.count(ft) == 0) { - func_encoding[ft] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); - } - - auto gtv = func_encoding[ft]; - auto c = Constructor(std::to_string(constructor_name++), {}, gtv); + auto type_key = TypeToString(ft); + if (map.count(type_key) == 0) { + auto gtv = GetFuncEncode(ft); + auto c = Constructor(std::to_string(constructor_counter++), {}, gtv); + map[type_key] = c; AddConstructor(gtv, c); - - if (apply_map.count(ft) == 0) { - apply_map[ft] = GlobalVar("apply" + TypeToString(ft)); - } - - auto gv = apply_map[ft]; - AddApplyCase(gv, ft, c, gv); + AddApplyCase(GetApplyFunction(ft), ft, c, gv, {}); } - - auto c = map[ft]; - return Call(c, {}); + return Call(map[type_key], {}); } - + + /*! + * \brief type to string + */ std::string TypeToString(const Type& t) { std::ostringstream s; s << t; return s.str(); } + /*! + * \brief get ADT handle for encoding type t + */ + GlobalTypeVar GetFuncEncode(const Type& t) { + auto adt_name = "T" + TypeToString(t); + if (func_encoding.count(adt_name) == 0) { + func_encoding[adt_name] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); + } + original_func_type_map[func_encoding[adt_name]] = t; + return func_encoding[adt_name]; + } + + /*! + * \brief get original function type represented by type t + */ + FuncType GetUnencodedType(const Type& t) { + auto tc = t.as(); + CHECK(tc) << "expected type call when getting original type from encoded type"; + auto gv = tc->func.as(); + CHECK(gv) << "expected global type var in encoded type"; + auto type = original_func_type_map[GetRef(gv)]; + CHECK(type.defined()) << "reverse mapping from encoded type to original type not found"; + return Downcast(type); + } + + /*! + * \brief get the apply function for calling datatypes encoding functions of type t + */ + GlobalVar GetApplyFunction(const Type& t) { + auto f_name = "apply" + TypeToString(t); + if (apply_map.count(f_name) == 0) { + apply_map[f_name] = GlobalVar("apply" + TypeToString(t)); + } + return apply_map[f_name]; + } + + /*! + * \brief specialize a function type + */ FuncType InstFuncType(const FuncTypeNode* fty, const Array type_args) { + CHECK(fty) << "InstFuncType functype is null"; auto map = tvm::Map(); for (size_t i = 0; i < type_args.size(); i++) { map.Set(fty->type_params[i], type_args[i]); @@ -257,7 +367,10 @@ class DefuncMutator : public ExprMutator { return Downcast(TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map)); } - Function Specialize(const FunctionNode* f, const Array type_args) { + /*! + * \brief specialize a function expression + */ + Function Specialize(const Function& f, const Array type_args) { auto map = tvm::Map(); for (size_t i = 0; i < type_args.size(); i++) { map.Set(f->type_params[i], type_args[i]); @@ -267,30 +380,38 @@ class DefuncMutator : public ExprMutator { return Downcast(copy); } + /*! + * \brief transform a function to be first order by transforming arg_types and + * using the `apply` function for applications + */ Function FirstifyVars(const Function& f) { CHECK(f->type_params.size() == 0) << "firstify function has type params"; - std::unordered_map var_bind_map; - for (auto var: f->params) { - if (auto var_type = var->checked_type().as()) { + tvm::Map var_bind_map; + Array params; + for (auto var : f->params) { + if (auto var_type = var->type_annotation.as()) { // first order parameter auto fop_type = GetRef(var_type); - if (func_encoding.count(fop_type) == 0) { - auto name = "T" + TypeToString(fop_type); - func_encoding[fop_type] = GlobalTypeVar(name, TypeKind::kAdtHandle); - } - auto adt = func_encoding[fop_type]; - var_bind_map[var] = Var(var->name_hint(), TypeCall(adt, {})); + auto adt = GetFuncEncode(fop_type); + auto new_var = Var(var->name_hint(), TypeCall(adt, {})); + mod->LookupTypeDef(adt); + var_bind_map.Set(var, new_var); + params.push_back(new_var); } else { - CHECK(!HasFuncType(var->checked_type())) << "nested function type in parameter not supported yet"; + CHECK(!HasFuncType(var->type_annotation)) + << "nested function type in parameter not supported yet"; + params.push_back(var); } } - return Downcast(Bind(f, var_bind_map)); + auto bind = Downcast(Bind(f, var_bind_map)); + return Function(params, this->VisitExpr(bind->body), bind->ret_type, {}); } }; Expr Defunctionalization(const Expr& e, const IRModule& mod) { + // e is the starting point of the program, all types MUST be known auto f = e.as(); CHECK(f) << "input need to be a function"; CHECK(f->type_params.size() == 0) << "no polymorphism supported for defunctionalization"; diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index 13725f52dc27..dba15440a5a3 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -18,25 +18,90 @@ import tvm from tvm import relay -from tvm.relay.transform import Defunctionalization, InferType, LambdaLift +from tvm.relay import transform, ExprVisitor, TypeVisitor +from tvm.relay.testing import Prelude + +# determine if type t is a FuncType or has a nested FuncType +def has_func_type(t): + class FuncTypeVisitor(TypeVisitor): + def __init__(self): + super().__init__() + self.has_func = False + + def visit_func_type(self, ftt): + self.has_func = True + + ftvisitor = FuncTypeVisitor() + ftvisitor.visit(t) + return ftvisitor.has_func + +# determine whether a program has any higher order functions +# a higher order function is defined as one that: +# - has function type arguments +# - returns a function +def assert_no_higher_order_functions(expr, mod): + class CheckFirstOrderVisitor(ExprVisitor): + def __init__(self, mod): + super().__init__() + self.mod = mod + self.hof = [] + self.visited_gv = set() + + def visit_call(self, call): + is_higher_order = False + # check return type + if (has_func_type(call.checked_type)): + is_higher_order = True + # check argument types + for a in call.args: + if (has_func_type(a.checked_type)): + is_higher_order = True + # if it is higher order, save it or debugging later + if is_higher_order: + self.hof.append(call) + super().visit_call(call) + + def visit_global_var(self, gv): + # visit global vars to visit entire program + if gv not in self.visited_gv: + self.visited_gv.add(gv) + self.visit(self.mod[gv]) + + mod = transform.InferType()(mod) + check_fo_visitor = CheckFirstOrderVisitor(mod) + check_fo_visitor.visit(expr) + + nl = '\n--------\n' + errmsg = f"""found {len(check_fo_visitor.hof)} higher order functions: + {nl.join(expr.astext() for expr in check_fo_visitor.hof)}""" + + assert len(check_fo_visitor.hof) == 0, errmsg + +# assert that a program is defunctionalized +# assumes program starts from mod['main'] +def defunctionalized(mod): + mod = transform.InferType()(mod) + mod['main'] = transform.Defunctionalization(mod['main'], mod) + mod = transform.InferType()(mod) + assert_no_higher_order_functions(mod['main'], mod) + + return mod def test_simple(): - code = """ + code = """ #[version = "0.0.5"] -def @apply[A, B](%f: fn(A) -> B, %xs: A) -> B { +def @simple[A, B](%f: fn(A) -> B, %xs: A) -> B { %f(%xs) } def @main(%l: float32) -> float32 { %0 = fn[A](%x: A) -> A { %x }; - @apply(%0, %l) + @simple(%0, %l) } """ - mod = tvm.parser.fromtext(code) - mod = LambdaLift()(mod) - mod = InferType()(mod) - expr = Defunctionalization(mod['main'], mod) + mod = tvm.parser.fromtext(code) + defunc_mod = defunctionalized(mod) def test_global_recursion(): code = """ @@ -59,48 +124,34 @@ def @main(%l: List[float32]) -> List[float32] { } """ mod = tvm.parser.fromtext(code) - # mod = LambdaLift()(mod) - mod = InferType()(mod) - # expr = Defunctionalization(mod['main'], mod) + defunc_mod = defunctionalized(mod) -def test_sum(): +def test_recursive_datatype(): + # CPS will create recursive datatype code = """ #[version = "0.0.5"] type List[A] { Cons(A, List[A]), Nil, } -def @main(%f: fn(int32) -> int32, %xs: List[int32]) -> int32 { +def @sum(%f: fn(int32) -> int32, %xs: List[int32]) -> int32 { match (%xs) { Cons(%x, %rest) => %0 = fn(%n) { %x + %f(%n) }; - @main(%0, %rest), + @sum(%0, %rest), Nil => %f(0), } } -""" - mod = tvm.parser.fromtext(code) - mod = LambdaLift()(mod) - mod = InferType()(mod) - print(mod) - -def test(): - code = """ -#[version = "0.0.5"] def @id[A](%x: A) -> A { %x } -def @main(%f: float32) -> float32 { - @id(@id)(%f) +def @main(%l: List[int32]) -> int32 { + @sum(@id, %l) } """ mod = tvm.parser.fromtext(code) - mod = InferType()(mod) - print(mod['main'].body.type_args) + defunc_mod = defunctionalized(mod) if __name__ == "__main__": - # pytest.main([__file__]) - # test_simple() - # test_global_recursion() - test_global_recursion() \ No newline at end of file + pytest.main([__file__]) \ No newline at end of file From 0b9573c9070114d13d3e825260688ea750ab745a Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 09:51:54 -0700 Subject: [PATCH 09/17] fix up test --- src/relay/transforms/defunctionalization.cc | 4 +- .../relay/test_pass_defunctionalization.py | 73 ++++++++++++++++++- 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index c6d12522ff8a..57e80fe3794c 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -57,6 +57,7 @@ * - functions cannot return function values * - function arguments are in two forms: identifier or a lambda abstraction * - no functions stored in datatype + * - functions are not let binded */ #include @@ -152,7 +153,8 @@ class DefuncMutator : public ExprMutator { auto call_args = Array(); Map free_var_bind_map; for (auto free_var : free_vars) { - // free vars are already encoded + // free vars are already encoded, can only exist within + // specialized functions if (free_var->type_annotation.defined()) { arg_types.push_back(free_var->type_annotation); } else { diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index dba15440a5a3..0dc006cdaf0b 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. import pytest +import numpy as np import tvm from tvm import relay +from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import transform, ExprVisitor, TypeVisitor from tvm.relay.testing import Prelude @@ -77,7 +79,8 @@ def visit_global_var(self, gv): assert len(check_fo_visitor.hof) == 0, errmsg -# assert that a program is defunctionalized +# assert that a program is defunctionalized and returns +# defunctionalized module # assumes program starts from mod['main'] def defunctionalized(mod): mod = transform.InferType()(mod) @@ -87,13 +90,48 @@ def defunctionalized(mod): return mod +# adt list to python list +def to_list(mod, l): + list = mod.get_global_type_var('List') + list_adt = mod[list] + cons = list_adt.constructors[0] + nil = list_adt.constructors[1] + + assert isinstance(l, ConstructorValue) + val = l + ret = [] + while True: + if val.tag == cons.tag: + ret.append(val.fields[0].asnumpy()) + val = val.fields[1] + else: + assert val.tag == nil.tag + break + return ret + +# list to adt list +def to_adt_list(mod, arr): + expr = mod['main'] + l = mod.get_global_type_var('List') + list_adt = mod[l] + cons = list_adt.constructors[0] + nil = list_adt.constructors[1] + + li = nil() + for a in arr: + li = cons(relay.const(a), li) + ex = relay.create_executor(mod=mod) + adt = ex.evaluate(li) + mod['main'] = expr + return adt + def test_simple(): code = """ #[version = "0.0.5"] def @simple[A, B](%f: fn(A) -> B, %xs: A) -> B { %f(%xs) } -def @main(%l: float32) -> float32 { +def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] { %0 = fn[A](%x: A) -> A { %x }; @@ -103,6 +141,17 @@ def @main(%l: float32) -> float32 { mod = tvm.parser.fromtext(code) defunc_mod = defunctionalized(mod) + input = np.random.rand(5,5).astype('float32') + + ex = relay.create_executor('debug', mod=mod) + defunc_ex = relay.create_executor('debug', mod=defunc_mod) + + out = ex.evaluate()(input) + defunc_out = defunc_ex.evaluate()(input) + + np.testing.assert_equal(out.asnumpy(), defunc_out.asnumpy()) + + def test_global_recursion(): code = """ #[version = "0.0.5"] @@ -126,6 +175,16 @@ def @main(%l: List[float32]) -> List[float32] { mod = tvm.parser.fromtext(code) defunc_mod = defunctionalized(mod) + input = np.random.rand(10).astype('float32') + + ex = relay.create_executor('debug', mod=mod) + defunc_ex = relay.create_executor('debug', mod=defunc_mod) + + out = ex.evaluate(mod['main'])(to_adt_list(mod, input)) + defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + + np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out)) + def test_recursive_datatype(): # CPS will create recursive datatype code = """ @@ -153,5 +212,15 @@ def @main(%l: List[int32]) -> int32 { mod = tvm.parser.fromtext(code) defunc_mod = defunctionalized(mod) + input = np.random.randint(1, 100, 10) + + ex = relay.create_executor('debug', mod=mod) + defunc_ex = relay.create_executor('debug', mod=defunc_mod) + + out = ex.evaluate(mod['main'])(to_adt_list(mod, input)) + defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + + tvm.testing.assert_allclose(out.asnumpy(), defunc_out.asnumpy()) + if __name__ == "__main__": pytest.main([__file__]) \ No newline at end of file From 85bbd4d0a97ce266bb8b4fbd6a42b2434c666cc2 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 09:57:54 -0700 Subject: [PATCH 10/17] fix --- python/tvm/relay/transform/transform.py | 2 ++ src/relay/transforms/pass_util.h | 2 +- src/relay/transforms/type_infer.cc | 4 ---- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index f64e80568ee1..5993d17e06f9 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -744,6 +744,8 @@ def Defunctionalization(expr, mod): The input expression, which is a Function or a GlobalVar. mod : tvm.IRModule + The IRModule containing function and type definitions, + which is mutated during this pass. Returns ------- diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 640ed788894f..d808b3143120 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -223,7 +223,7 @@ std::pair CalcScope(const DependencyGraph& dg); Scope LCA(Scope lhs, Scope rhs); /*! \brief if the expression is a GlobalVar, transform to it's expression. -*/ + */ Expr DeGlobal(const Optional& mod, const Expr& e); /* Special care is needed to handle local recursion. diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 5579d8744e5a..e110737d6226 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -375,10 +375,6 @@ class TypeInferencer : private ExprFunctor, subst_map.Set(fn_ty->type_params[i], ty_args[i]); } - for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) { - subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType)); - } - Type ret_type = fn_ty->ret_type; // If the function type is incomplete, place a new IncompleteType From 8dd7dc3c60710cd83e2ee834e42c5f8bd58a4f80 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 10:01:20 -0700 Subject: [PATCH 11/17] remove DeGlobal --- src/relay/analysis/util.cc | 15 --------------- src/relay/transforms/defunctionalization.cc | 2 +- src/relay/transforms/gradient.cc | 16 ++++++++++++++++ src/relay/transforms/pass_util.h | 4 ---- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index ef25d0e54add..b98106a091b3 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -449,21 +449,6 @@ Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { return ret; } -Expr DeGlobal(const Optional& mod, const Expr& e) { - const auto* x = e.as(); - - if (mod.defined() && x) { - BaseFunc base_func = mod.value()->Lookup(GetRef(x)); - if (auto* n = base_func.as()) { - return GetRef(n); - } else { - return e; - } - } else { - return e; - } -} - struct IsDynamicVisitor : public TypeVisitor { bool is_dyn{false}; void VisitType_(const TensorTypeNode* tt) { diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 57e80fe3794c..43fa84c87f44 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -184,7 +184,7 @@ class DefuncMutator : public ExprMutator { } else { specialized_gv_map[name] = gv; // clone and specialize with specific type - auto clone = Downcast(DeDup(DeGlobal(mod, GetRef(op)))); + auto clone = Downcast(DeDup(mod->Lookup(GetRef(op)))); auto specialized_function = Specialize(clone, call->type_args); // change var types and change all applications to use `apply` method auto f = Downcast(FirstifyVars(specialized_function)); diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index 81ce64afe7a5..9c472542cc91 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -78,6 +78,22 @@ Type WithGradientType(const Type& t) { return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {}); } +//! \brief if the expression is a GlobalVar, transform to it's expression. +Expr DeGlobal(const Optional& mod, const Expr& e) { + const auto* x = e.as(); + + if (mod.defined() && x) { + BaseFunc base_func = mod.value()->Lookup(GetRef(x)); + if (auto* n = base_func.as()) { + return GetRef(n); + } else { + return e; + } + } else { + return e; + } +} + /*! \brief A fragment of the program being built by the automatic differentation * pass. */ diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index d808b3143120..f3c99ccfa120 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -222,10 +222,6 @@ std::pair CalcScope(const DependencyGraph& dg); */ Scope LCA(Scope lhs, Scope rhs); -/*! \brief if the expression is a GlobalVar, transform to it's expression. - */ -Expr DeGlobal(const Optional& mod, const Expr& e); - /* Special care is needed to handle local recursion. * Fill additionally take a (possibly null) Var argument, * If it is not null, Fill is required to bind the transformed result to that var. From 572703e765e0b6d86e778ec1f6a6f2a035659997 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 10:08:27 -0700 Subject: [PATCH 12/17] lint --- src/relay/transforms/defunctionalization.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 43fa84c87f44..3b0f1a6f1b03 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -100,11 +100,11 @@ bool IsHigherOrderFunc(const FuncType& t) { } /*! - * \brief apply Defunctionalization transform + * \brief mutator for driving the Defunctionalization transformation */ class DefuncMutator : public ExprMutator { public: - DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {} + explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {} Expr VisitExpr_(const CallNode* call) { if (auto op = call->op.as()) { @@ -233,7 +233,7 @@ class DefuncMutator : public ExprMutator { ObjectEqual> gv_datatype_map; // use monotonically increasing integer to represent new constructor_name - unsigned long constructor_counter; + uint64_t constructor_counter; /*! * \brief add a constructor to the GlobalTypeVar, creating a new TypeDef if GlobalTypeVar does not From 14c22972d1b3f781adb320a47c961b76fbcb9945 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 10:55:05 -0700 Subject: [PATCH 13/17] fix std move --- src/relay/transforms/defunctionalization.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 3b0f1a6f1b03..ed45590e3452 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -211,8 +211,7 @@ class DefuncMutator : public ExprMutator { args.push_back(this->VisitExpr(arg)); } - auto e = Call(GetApplyFunction(op_type), args); - return e; + return Call(GetApplyFunction(op_type), args); } return ExprMutator::VisitExpr_(call); } From ad71b1ebcd643764dfd6dcab67ae7d2370480b5f Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 11:25:37 -0700 Subject: [PATCH 14/17] comments --- python/tvm/relay/transform/transform.py | 19 +++++++++++++------ src/relay/transforms/defunctionalization.cc | 10 ++++------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 5993d17e06f9..0ee107027631 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -736,21 +736,28 @@ def gradient(expr, mod=None, mode='higher_order'): return _ffi_api.gradient(expr, mod) raise Exception('unknown mode') -def Defunctionalization(expr, mod): +def Defunctionalization(func, mod): """ + Performs defunctionalization on func, + transforming func from a higher-order program to a first-order program. + + At each call site, the function is cloned and type parameters are substituted in. + Function arguments are encoded as datatypes + and additional apply functions are used for application. + Parameters ---------- - expr : tvm.relay.Expr - The input expression, which is a Function or a GlobalVar. + func : tvm.relay.Function + The input function, which should not be polymorphic or be higher-order. mod : tvm.IRModule The IRModule containing function and type definitions, - which is mutated during this pass. + which is also mutated during this pass. Returns ------- - expr : tvm.relay.Expr - The transformed expression. + expr : tvm.relay.Function + The output function. """ return _ffi_api.Defunctionalization(expr, mod) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index ed45590e3452..e7776f5cc0d5 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -411,17 +411,15 @@ class DefuncMutator : public ExprMutator { } }; -Expr Defunctionalization(const Expr& e, const IRModule& mod) { - // e is the starting point of the program, all types MUST be known - auto f = e.as(); - CHECK(f) << "input need to be a function"; +Expr Defunctionalization(const Function& f, const IRModule& mod) { + // f is the starting point of the program, all types MUST be known CHECK(f->type_params.size() == 0) << "no polymorphism supported for defunctionalization"; for (const auto& p : f->params) { - CHECK(!HasFuncType(p)) << "input parameters cannot have func type"; + CHECK(!HasFuncType(p)) << "program cannot have func type parameters"; } CHECK(!HasFuncType(f->ret_type)) << "return type cannot contain function"; - return DefuncMutator(mod).VisitExpr(e); + return Downcast(DefuncMutator(mod).VisitExpr(f)); } TVM_REGISTER_GLOBAL("relay._transform.Defunctionalization").set_body_typed(Defunctionalization); From 623d4b1cdda481d1daf3852cf3f7e14186bd6d11 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Fri, 4 Sep 2020 11:28:05 -0700 Subject: [PATCH 15/17] fix comments --- python/tvm/relay/transform/transform.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 0ee107027631..60a7aa38b369 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -749,6 +749,8 @@ def Defunctionalization(func, mod): ---------- func : tvm.relay.Function The input function, which should not be polymorphic or be higher-order. + This is because all types must be known and we can't encode function arguments + to the program itself. mod : tvm.IRModule The IRModule containing function and type definitions, @@ -759,7 +761,7 @@ def Defunctionalization(func, mod): expr : tvm.relay.Function The output function. """ - return _ffi_api.Defunctionalization(expr, mod) + return _ffi_api.Defunctionalization(func, mod) def to_cps(func, mod=None): """ From bbab26d681a02ff5cd083723d668e7c03452432e Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 9 Sep 2020 15:16:56 -0700 Subject: [PATCH 16/17] review --- src/relay/transforms/defunctionalization.cc | 128 +++++++++--------- .../relay/test_pass_defunctionalization.py | 6 +- 2 files changed, 69 insertions(+), 65 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index e7776f5cc0d5..dc0baa950e95 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -72,20 +72,15 @@ namespace tvm { namespace relay { -struct FuncTypeVisitor : TypeVisitor { - bool has_func_type; - FuncTypeVisitor() : has_func_type(false) {} - - void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; } -}; -// determine if expr contains a FuncType -bool HasFuncType(const Expr& e) { - auto visitor = FuncTypeVisitor(); - visitor.VisitType(e->checked_type()); - return visitor.has_func_type; -} // determine if type contains a FuncType bool HasFuncType(const Type& t) { + struct FuncTypeVisitor : TypeVisitor { + bool has_func_type; + FuncTypeVisitor() : has_func_type(false) {} + + void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; } + }; + auto visitor = FuncTypeVisitor(); visitor.VisitType(t); return visitor.has_func_type; @@ -108,11 +103,11 @@ class DefuncMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call) { if (auto op = call->op.as()) { - CHECK(call->type_args.size() == op->checked_type().as()->type_params.size()) + CHECK_EQ(call->type_args.size(), op->checked_type().as()->type_params.size()) << "all type args must be explicit"; auto op_type = InstFuncType(op->checked_type().as(), call->type_args); - CHECK(FreeTypeVars(op_type, mod).size() == 0) << "free type vars in instantiated"; + CHECK_EQ(FreeTypeVars(op_type, mod).size(), 0) << "free type vars in instantiated"; CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported"; if (!IsHigherOrderFunc(op_type)) { @@ -130,52 +125,7 @@ class DefuncMutator : public ExprMutator { continue; } - // we assume arg is either an identifier (var or globalvar) or a function - CHECK(type.as()) << "assume no nested functions"; - CHECK(arg.as() || arg.as() || arg.as()) - << "assume all first-order-parameters are identifiers or functions"; - - if (arg.as()) { - // variable with functype will be encoded as datatype in surrounding function - args.push_back(arg); - } - if (arg.as()) { - args.push_back(EncodeGlobalVar(Downcast(arg), Downcast(type))); - } - if (auto fn = arg.as()) { - // we handle free vars in anonymous functions by adding arguments to - // the constructor function - auto free_vars = FreeVars(arg); - auto ft = Downcast(type); - - auto arg_types = Array(); - auto pattern_vars = Array(); - auto call_args = Array(); - Map free_var_bind_map; - for (auto free_var : free_vars) { - // free vars are already encoded, can only exist within - // specialized functions - if (free_var->type_annotation.defined()) { - arg_types.push_back(free_var->type_annotation); - } else { - arg_types.push_back(free_var->checked_type()); - } - auto new_var = Var(free_var->name_hint(), free_var->type_annotation); - free_var_bind_map.Set(free_var, new_var); - pattern_vars.push_back(PatternVar(new_var)); - call_args.push_back(free_var); - } - auto gtv = GetFuncEncode(ft); - auto c = Constructor(std::to_string(constructor_counter++), arg_types, gtv); - AddConstructor(gtv, c); - - auto apply_gv = GetApplyFunction(ft); - auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); - AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params), - pattern_vars); - - args.push_back(Call(c, call_args)); - } + args.push_back(EncodeArg(arg, type)); } auto name = op->name_hint + TypeToString(op_type); auto gv = GlobalVar(name); @@ -294,6 +244,55 @@ class DefuncMutator : public ExprMutator { } } + Expr EncodeArg(const Expr& arg, const Type& type) { + // we assume arg is either an identifier (var or globalvar) or a function + CHECK(type.as()) << "assume no nested functions"; + CHECK(arg.as() || arg.as() || arg.as()) + << "assume all first-order-parameters are identifiers or functions"; + + if (arg.as()) { + // variable with functype will be encoded as datatype in surrounding function + return arg; + } else if (arg.as()) { + return EncodeGlobalVar(Downcast(arg), Downcast(type)); + } else if (auto fn = arg.as()) { + // we handle free vars in anonymous functions by adding arguments to + // the constructor function + auto free_vars = FreeVars(arg); + auto ft = Downcast(type); + + auto arg_types = Array(); + auto pattern_vars = Array(); + auto call_args = Array(); + Map free_var_bind_map; + for (auto free_var : free_vars) { + // free vars are already encoded, can only exist within + // specialized functions + if (free_var->type_annotation.defined()) { + arg_types.push_back(free_var->type_annotation); + } else { + arg_types.push_back(free_var->checked_type()); + } + auto new_var = Var(free_var->name_hint(), free_var->type_annotation); + free_var_bind_map.Set(free_var, new_var); + pattern_vars.push_back(PatternVar(new_var)); + call_args.push_back(free_var); + } + auto gtv = GetFuncEncode(ft); + auto c = Constructor(std::to_string(++constructor_counter), arg_types, gtv); + AddConstructor(gtv, c); + + auto apply_gv = GetApplyFunction(ft); + auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); + AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params), + pattern_vars); + + return Call(c, call_args); + } + + throw std::runtime_error("EncodeArg failed to cast arg into identifier node or function node"); + } + /*! * \brief encode a global var with a specialized type with a datatype */ @@ -323,7 +322,7 @@ class DefuncMutator : public ExprMutator { * \brief get ADT handle for encoding type t */ GlobalTypeVar GetFuncEncode(const Type& t) { - auto adt_name = "T" + TypeToString(t); + auto adt_name = "Defunc" + TypeToString(t); if (func_encoding.count(adt_name) == 0) { func_encoding[adt_name] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); } @@ -360,6 +359,8 @@ class DefuncMutator : public ExprMutator { */ FuncType InstFuncType(const FuncTypeNode* fty, const Array type_args) { CHECK(fty) << "InstFuncType functype is null"; + CHECK_EQ(fty->type_params.size(), type_args.size()) + << "size mismatch between function type params and type args"; auto map = tvm::Map(); for (size_t i = 0; i < type_args.size(); i++) { map.Set(fty->type_params[i], type_args[i]); @@ -372,6 +373,9 @@ class DefuncMutator : public ExprMutator { * \brief specialize a function expression */ Function Specialize(const Function& f, const Array type_args) { + CHECK_EQ(f->type_params.size(), type_args.size()) + << "cannot specialize function with size mismatch between function type params and type " + "args"; auto map = tvm::Map(); for (size_t i = 0; i < type_args.size(); i++) { map.Set(f->type_params[i], type_args[i]); @@ -415,7 +419,7 @@ Expr Defunctionalization(const Function& f, const IRModule& mod) { // f is the starting point of the program, all types MUST be known CHECK(f->type_params.size() == 0) << "no polymorphism supported for defunctionalization"; for (const auto& p : f->params) { - CHECK(!HasFuncType(p)) << "program cannot have func type parameters"; + CHECK(!HasFuncType(p->checked_type())) << "program cannot have func type parameters"; } CHECK(!HasFuncType(f->ret_type)) << "return type cannot contain function"; diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py index 0dc006cdaf0b..ac54ebff0271 100644 --- a/tests/python/relay/test_pass_defunctionalization.py +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -58,7 +58,7 @@ def visit_call(self, call): for a in call.args: if (has_func_type(a.checked_type)): is_higher_order = True - # if it is higher order, save it or debugging later + # if it is higher order, save it for debugging later if is_higher_order: self.hof.append(call) super().visit_call(call) @@ -193,8 +193,8 @@ def test_recursive_datatype(): Cons(A, List[A]), Nil, } -def @sum(%f: fn(int32) -> int32, %xs: List[int32]) -> int32 { - match (%xs) { +def @sum(%f: fn(int32) -> int32, %k: List[int32]) -> int32 { + match (%k) { Cons(%x, %rest) => %0 = fn(%n) { %x + %f(%n) }; From a573c1769fb284abbe158d92918fd9a6b9dd6f3f Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 9 Sep 2020 21:11:05 -0700 Subject: [PATCH 17/17] style --- src/relay/transforms/defunctionalization.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index dc0baa950e95..ec614d23a02e 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -122,10 +122,9 @@ class DefuncMutator : public ExprMutator { auto type = op_type->arg_types[i]; if (!HasFuncType(type)) { args.push_back(arg); - continue; + } else { + args.push_back(EncodeArg(arg, type)); } - - args.push_back(EncodeArg(arg, type)); } auto name = op->name_hint + TypeToString(op_type); auto gv = GlobalVar(name);