diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index de3f9861c96e..60a7aa38b369 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -736,6 +736,32 @@ def gradient(expr, mod=None, mode='higher_order'): return _ffi_api.gradient(expr, mod) raise Exception('unknown mode') +def Defunctionalization(func, mod): + """ + Performs defunctionalization on func, + transforming func from a higher-order program to a first-order program. + + At each call site, the function is cloned and type parameters are substituted in. + Function arguments are encoded as datatypes + and additional apply functions are used for application. + + Parameters + ---------- + func : tvm.relay.Function + The input function, which should not be polymorphic or be higher-order. + This is because all types must be known and we can't encode function arguments + to the program itself. + + mod : tvm.IRModule + The IRModule containing function and type definitions, + which is also mutated during this pass. + + Returns + ------- + expr : tvm.relay.Function + The output function. + """ + return _ffi_api.Defunctionalization(func, mod) def to_cps(func, mod=None): """ diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc new file mode 100644 index 000000000000..ec614d23a02e --- /dev/null +++ b/src/relay/transforms/defunctionalization.cc @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file defunctionalization.cc + * + * \brief Defunctionalization for Relay IR + * + * This pass transforms a higher-order program into a first-order program with defunctionalization. + * This means that all higher order functions (i.e functions that take function arguments or return + * functions) should be transformed into a semantically equivalent first order one. + * + * This pass implements a basic typed defunctionalization method. + * All higher order functions are cloned and specialized (so that there are no type params). + * Function type arguments are encoded as datatypes and a helper `apply` function is used + * to "call" them. + * + * For example, take the following higher order program: + * fun map F y = case y of + * Nil => Nil + * | Cons(x, XS) => Cons(F z, map F XS) + * fun addone 1 = map (\x -> \x + 1) 1 + * + * where `addone` is our program. + * When we call the `map` function, we see that it is a higher-order function, + * but we can clone `map ` function and specialize it with the type_params of the call. + * In addition, our function argument `(\x -> \x + 1)` will be encoded as a datatype constructor, + * which we will call `incr`, and all calls to `F` in our specialized map function will use the + * helper `apply` function. + * + * After defunctionalization, we get: + * fun apply encoding arg = case encoding of + * “incr” => incr arg + * fun map’ F y = case y of + * Nil => Nil + * | Cons(x, xs) => Cons(apply F x, map’ F xs) + * fun addone 1 = map’ “incr” 1 + * + * Currently, defunctionalization makes the following assumptions: + * - functions cannot return function values + * - function arguments are in two forms: identifier or a lambda abstraction + * - no functions stored in datatype + * - functions are not let binded + */ + +#include +#include +#include +#include +#include +#include + +#include "../analysis/type_solver.h" +#include "../transforms/pass_util.h" +namespace tvm { +namespace relay { + +// determine if type contains a FuncType +bool HasFuncType(const Type& t) { + struct FuncTypeVisitor : TypeVisitor { + bool has_func_type; + FuncTypeVisitor() : has_func_type(false) {} + + void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; } + }; + + auto visitor = FuncTypeVisitor(); + visitor.VisitType(t); + return visitor.has_func_type; +} +// determine if FuncType is a higher order type +bool IsHigherOrderFunc(const FuncType& t) { + bool higher_order = false; + for (auto arg : t->arg_types) { + higher_order |= HasFuncType(arg); + } + return higher_order |= HasFuncType(t->ret_type); +} + +/*! + * \brief mutator for driving the Defunctionalization transformation + */ +class DefuncMutator : public ExprMutator { + public: + explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {} + + Expr VisitExpr_(const CallNode* call) { + if (auto op = call->op.as()) { + CHECK_EQ(call->type_args.size(), op->checked_type().as()->type_params.size()) + << "all type args must be explicit"; + + auto op_type = InstFuncType(op->checked_type().as(), call->type_args); + CHECK_EQ(FreeTypeVars(op_type, mod).size(), 0) << "free type vars in instantiated"; + CHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported"; + + if (!IsHigherOrderFunc(op_type)) { + // not higher order function + return ExprMutator::VisitExpr_(call); + } + + // first we encode function arguments + Array args; + for (size_t i = 0; i < call->args.size(); i++) { + auto arg = call->args[i]; + auto type = op_type->arg_types[i]; + if (!HasFuncType(type)) { + args.push_back(arg); + } else { + args.push_back(EncodeArg(arg, type)); + } + } + auto name = op->name_hint + TypeToString(op_type); + auto gv = GlobalVar(name); + if (specialized_gv_map.count(name)) { + gv = specialized_gv_map[name]; + } else { + specialized_gv_map[name] = gv; + // clone and specialize with specific type + auto clone = Downcast(DeDup(mod->Lookup(GetRef(op)))); + auto specialized_function = Specialize(clone, call->type_args); + // change var types and change all applications to use `apply` method + auto f = Downcast(FirstifyVars(specialized_function)); + mod->Add(gv, f); + } + return Call(gv, args); + } else if (auto op = call->op.as()) { + // reduction by applying vars + std::unordered_map var_binding_map; + for (size_t i = 0; i < op->params.size(); i++) { + var_binding_map[op->params[i]] = call->args[i]; + } + auto e = Bind(op->body, var_binding_map); + return this->VisitExpr(e); + } else if (auto op = call->op.as()) { + // var node will be encoded as datatype + // so we need to use the `apply` helper method + auto var_original_type = GetUnencodedType(op->type_annotation).as(); + CHECK(var_original_type) << "var original type not saved in var_save_type map"; + auto op_type = InstFuncType(var_original_type, call->type_args); + + Array args = {GetRef(op)}; + for (auto arg : call->args) { + args.push_back(this->VisitExpr(arg)); + } + + return Call(GetApplyFunction(op_type), args); + } + return ExprMutator::VisitExpr_(call); + } + + private: + // module + IRModule mod; + // gv + str(type) to specialized clone gv + std::unordered_map specialized_gv_map; + // str(func_type) to ADT + std::unordered_map func_encoding; + // str(func_tyoe) to apply gv + std::unordered_map apply_map; + // encoded ADT handle to FuncType + std::unordered_map original_func_type_map; + // gv to (str(func_type) to constructor encoding) + std::unordered_map, ObjectHash, + ObjectEqual> + gv_datatype_map; + // use monotonically increasing integer to represent new constructor_name + uint64_t constructor_counter; + + /*! + * \brief add a constructor to the GlobalTypeVar, creating a new TypeDef if GlobalTypeVar does not + * exist + */ + void AddConstructor(GlobalTypeVar gtv, Constructor c) { + if (!mod->ContainGlobalTypeVar(gtv->name_hint)) { + mod->AddTypeDef(gtv, TypeData(gtv, {}, {c})); + } else { + auto typedata = mod->LookupTypeDef(gtv); + auto constructors = typedata->constructors; + constructors.push_back(c); + mod->UpdateTypeDef(gtv, TypeData(typedata->header, typedata->type_vars, constructors)); + } + } + /*! + * \brief add a case to the apply function, creating the function if it does not exist + * + * \param apply_gv GlobalVar of the apply function + * \param ft is the type functions the apply function handles + * \param c constructor to add a case for + * \param expr calls this expr with the args to the apply_gv + * \param patterns PatterVars to match with the constructor, used for handling free vars in + * functions + */ + void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr, + const Array patterns) { + CHECK(c->inputs.size() == patterns.size()) + << "constructor function and pattern vars have different sizes"; + if (!mod->ContainGlobalVar(apply_gv->name_hint)) { + auto x = Var("x", TypeCall(c->belong_to, {})); + auto vars = Array({x}); + auto args = Array(); + for (auto t : ft->arg_types) { + auto y = Var("y", t); + vars.push_back(y); + args.push_back(y); + } + + auto clauses = Array({Clause(PatternConstructor(c, patterns), Call(expr, args))}); + auto body = Match(x, clauses); + auto f = Function(vars, body, ft->ret_type, {}); + + mod->Add(apply_gv, f); + } else { + auto f = Downcast(mod->Lookup(apply_gv)); + auto body = f->body.as(); + CHECK(body) << "internal invariant broken; apply function body should be a match node"; + + auto clauses = body->clauses; + auto x = f->params[0]; + auto args = Array(); + for (size_t i = 1; i < f->params.size(); i++) { + args.push_back(f->params[i]); + } + clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args))); + + mod->Add(apply_gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true); + } + } + + Expr EncodeArg(const Expr& arg, const Type& type) { + // we assume arg is either an identifier (var or globalvar) or a function + CHECK(type.as()) << "assume no nested functions"; + CHECK(arg.as() || arg.as() || arg.as()) + << "assume all first-order-parameters are identifiers or functions"; + + if (arg.as()) { + // variable with functype will be encoded as datatype in surrounding function + return arg; + } else if (arg.as()) { + return EncodeGlobalVar(Downcast(arg), Downcast(type)); + } else if (auto fn = arg.as()) { + // we handle free vars in anonymous functions by adding arguments to + // the constructor function + auto free_vars = FreeVars(arg); + auto ft = Downcast(type); + + auto arg_types = Array(); + auto pattern_vars = Array(); + auto call_args = Array(); + Map free_var_bind_map; + for (auto free_var : free_vars) { + // free vars are already encoded, can only exist within + // specialized functions + if (free_var->type_annotation.defined()) { + arg_types.push_back(free_var->type_annotation); + } else { + arg_types.push_back(free_var->checked_type()); + } + auto new_var = Var(free_var->name_hint(), free_var->type_annotation); + free_var_bind_map.Set(free_var, new_var); + pattern_vars.push_back(PatternVar(new_var)); + call_args.push_back(free_var); + } + auto gtv = GetFuncEncode(ft); + auto c = Constructor(std::to_string(++constructor_counter), arg_types, gtv); + AddConstructor(gtv, c); + + auto apply_gv = GetApplyFunction(ft); + auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); + AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params), + pattern_vars); + + return Call(c, call_args); + } + + throw std::runtime_error("EncodeArg failed to cast arg into identifier node or function node"); + } + + /*! + * \brief encode a global var with a specialized type with a datatype + */ + Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) { + auto map = gv_datatype_map[gv]; + auto type_key = TypeToString(ft); + if (map.count(type_key) == 0) { + auto gtv = GetFuncEncode(ft); + auto c = Constructor(std::to_string(constructor_counter++), {}, gtv); + map[type_key] = c; + AddConstructor(gtv, c); + AddApplyCase(GetApplyFunction(ft), ft, c, gv, {}); + } + return Call(map[type_key], {}); + } + + /*! + * \brief type to string + */ + std::string TypeToString(const Type& t) { + std::ostringstream s; + s << t; + return s.str(); + } + + /*! + * \brief get ADT handle for encoding type t + */ + GlobalTypeVar GetFuncEncode(const Type& t) { + auto adt_name = "Defunc" + TypeToString(t); + if (func_encoding.count(adt_name) == 0) { + func_encoding[adt_name] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); + } + original_func_type_map[func_encoding[adt_name]] = t; + return func_encoding[adt_name]; + } + + /*! + * \brief get original function type represented by type t + */ + FuncType GetUnencodedType(const Type& t) { + auto tc = t.as(); + CHECK(tc) << "expected type call when getting original type from encoded type"; + auto gv = tc->func.as(); + CHECK(gv) << "expected global type var in encoded type"; + auto type = original_func_type_map[GetRef(gv)]; + CHECK(type.defined()) << "reverse mapping from encoded type to original type not found"; + return Downcast(type); + } + + /*! + * \brief get the apply function for calling datatypes encoding functions of type t + */ + GlobalVar GetApplyFunction(const Type& t) { + auto f_name = "apply" + TypeToString(t); + if (apply_map.count(f_name) == 0) { + apply_map[f_name] = GlobalVar("apply" + TypeToString(t)); + } + return apply_map[f_name]; + } + + /*! + * \brief specialize a function type + */ + FuncType InstFuncType(const FuncTypeNode* fty, const Array type_args) { + CHECK(fty) << "InstFuncType functype is null"; + CHECK_EQ(fty->type_params.size(), type_args.size()) + << "size mismatch between function type params and type args"; + auto map = tvm::Map(); + for (size_t i = 0; i < type_args.size(); i++) { + map.Set(fty->type_params[i], type_args[i]); + } + // copy with typevars removed + return Downcast(TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map)); + } + + /*! + * \brief specialize a function expression + */ + Function Specialize(const Function& f, const Array type_args) { + CHECK_EQ(f->type_params.size(), type_args.size()) + << "cannot specialize function with size mismatch between function type params and type " + "args"; + auto map = tvm::Map(); + for (size_t i = 0; i < type_args.size(); i++) { + map.Set(f->type_params[i], type_args[i]); + } + // copy with typevars removed + auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map); + return Downcast(copy); + } + + /*! + * \brief transform a function to be first order by transforming arg_types and + * using the `apply` function for applications + */ + Function FirstifyVars(const Function& f) { + CHECK(f->type_params.size() == 0) << "firstify function has type params"; + + tvm::Map var_bind_map; + Array params; + for (auto var : f->params) { + if (auto var_type = var->type_annotation.as()) { + // first order parameter + auto fop_type = GetRef(var_type); + auto adt = GetFuncEncode(fop_type); + auto new_var = Var(var->name_hint(), TypeCall(adt, {})); + mod->LookupTypeDef(adt); + var_bind_map.Set(var, new_var); + params.push_back(new_var); + } else { + CHECK(!HasFuncType(var->type_annotation)) + << "nested function type in parameter not supported yet"; + params.push_back(var); + } + } + + auto bind = Downcast(Bind(f, var_bind_map)); + return Function(params, this->VisitExpr(bind->body), bind->ret_type, {}); + } +}; + +Expr Defunctionalization(const Function& f, const IRModule& mod) { + // f is the starting point of the program, all types MUST be known + CHECK(f->type_params.size() == 0) << "no polymorphism supported for defunctionalization"; + for (const auto& p : f->params) { + CHECK(!HasFuncType(p->checked_type())) << "program cannot have func type parameters"; + } + CHECK(!HasFuncType(f->ret_type)) << "return type cannot contain function"; + + return Downcast(DefuncMutator(mod).VisitExpr(f)); +} + +TVM_REGISTER_GLOBAL("relay._transform.Defunctionalization").set_body_typed(Defunctionalization); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_defunctionalization.py b/tests/python/relay/test_pass_defunctionalization.py new file mode 100644 index 000000000000..ac54ebff0271 --- /dev/null +++ b/tests/python/relay/test_pass_defunctionalization.py @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.backend.interpreter import ConstructorValue +from tvm.relay import transform, ExprVisitor, TypeVisitor +from tvm.relay.testing import Prelude + +# determine if type t is a FuncType or has a nested FuncType +def has_func_type(t): + class FuncTypeVisitor(TypeVisitor): + def __init__(self): + super().__init__() + self.has_func = False + + def visit_func_type(self, ftt): + self.has_func = True + + ftvisitor = FuncTypeVisitor() + ftvisitor.visit(t) + return ftvisitor.has_func + +# determine whether a program has any higher order functions +# a higher order function is defined as one that: +# - has function type arguments +# - returns a function +def assert_no_higher_order_functions(expr, mod): + class CheckFirstOrderVisitor(ExprVisitor): + def __init__(self, mod): + super().__init__() + self.mod = mod + self.hof = [] + self.visited_gv = set() + + def visit_call(self, call): + is_higher_order = False + # check return type + if (has_func_type(call.checked_type)): + is_higher_order = True + # check argument types + for a in call.args: + if (has_func_type(a.checked_type)): + is_higher_order = True + # if it is higher order, save it for debugging later + if is_higher_order: + self.hof.append(call) + super().visit_call(call) + + def visit_global_var(self, gv): + # visit global vars to visit entire program + if gv not in self.visited_gv: + self.visited_gv.add(gv) + self.visit(self.mod[gv]) + + mod = transform.InferType()(mod) + check_fo_visitor = CheckFirstOrderVisitor(mod) + check_fo_visitor.visit(expr) + + nl = '\n--------\n' + errmsg = f"""found {len(check_fo_visitor.hof)} higher order functions: + {nl.join(expr.astext() for expr in check_fo_visitor.hof)}""" + + assert len(check_fo_visitor.hof) == 0, errmsg + +# assert that a program is defunctionalized and returns +# defunctionalized module +# assumes program starts from mod['main'] +def defunctionalized(mod): + mod = transform.InferType()(mod) + mod['main'] = transform.Defunctionalization(mod['main'], mod) + mod = transform.InferType()(mod) + assert_no_higher_order_functions(mod['main'], mod) + + return mod + +# adt list to python list +def to_list(mod, l): + list = mod.get_global_type_var('List') + list_adt = mod[list] + cons = list_adt.constructors[0] + nil = list_adt.constructors[1] + + assert isinstance(l, ConstructorValue) + val = l + ret = [] + while True: + if val.tag == cons.tag: + ret.append(val.fields[0].asnumpy()) + val = val.fields[1] + else: + assert val.tag == nil.tag + break + return ret + +# list to adt list +def to_adt_list(mod, arr): + expr = mod['main'] + l = mod.get_global_type_var('List') + list_adt = mod[l] + cons = list_adt.constructors[0] + nil = list_adt.constructors[1] + + li = nil() + for a in arr: + li = cons(relay.const(a), li) + ex = relay.create_executor(mod=mod) + adt = ex.evaluate(li) + mod['main'] = expr + return adt + +def test_simple(): + code = """ +#[version = "0.0.5"] +def @simple[A, B](%f: fn(A) -> B, %xs: A) -> B { + %f(%xs) +} +def @main(%l: Tensor[(5, 5), float32]) -> Tensor[(5, 5), float32] { + %0 = fn[A](%x: A) -> A { + %x + }; + @simple(%0, %l) +} +""" + mod = tvm.parser.fromtext(code) + defunc_mod = defunctionalized(mod) + + input = np.random.rand(5,5).astype('float32') + + ex = relay.create_executor('debug', mod=mod) + defunc_ex = relay.create_executor('debug', mod=defunc_mod) + + out = ex.evaluate()(input) + defunc_out = defunc_ex.evaluate()(input) + + np.testing.assert_equal(out.asnumpy(), defunc_out.asnumpy()) + + +def test_global_recursion(): + code = """ +#[version = "0.0.5"] +type List[A] { + Cons(A, List[A]), + Nil, +} +def @id[A](%x: A) -> A { + %x +} +def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] { + match (%xs) { + Cons(%x, %rest) => Cons(%f(%x), @map(%f, %rest)), + Nil => Nil, + } +} +def @main(%l: List[float32]) -> List[float32] { + @map(@id, %l) +} +""" + mod = tvm.parser.fromtext(code) + defunc_mod = defunctionalized(mod) + + input = np.random.rand(10).astype('float32') + + ex = relay.create_executor('debug', mod=mod) + defunc_ex = relay.create_executor('debug', mod=defunc_mod) + + out = ex.evaluate(mod['main'])(to_adt_list(mod, input)) + defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + + np.testing.assert_array_equal(to_list(mod, out), to_list(defunc_mod, defunc_out)) + +def test_recursive_datatype(): + # CPS will create recursive datatype + code = """ +#[version = "0.0.5"] +type List[A] { + Cons(A, List[A]), + Nil, +} +def @sum(%f: fn(int32) -> int32, %k: List[int32]) -> int32 { + match (%k) { + Cons(%x, %rest) => %0 = fn(%n) { + %x + %f(%n) + }; + @sum(%0, %rest), + Nil => %f(0), + } +} +def @id[A](%x: A) -> A { + %x +} +def @main(%l: List[int32]) -> int32 { + @sum(@id, %l) +} +""" + mod = tvm.parser.fromtext(code) + defunc_mod = defunctionalized(mod) + + input = np.random.randint(1, 100, 10) + + ex = relay.create_executor('debug', mod=mod) + defunc_ex = relay.create_executor('debug', mod=defunc_mod) + + out = ex.evaluate(mod['main'])(to_adt_list(mod, input)) + defunc_out = defunc_ex.evaluate()(to_adt_list(defunc_mod, input)) + + tvm.testing.assert_allclose(out.asnumpy(), defunc_out.asnumpy()) + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file