diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index d322710ec95a..4f57e956b0c9 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -210,6 +210,16 @@ TVM_DLL Pass FastMath(); */ TVM_DLL Pass InferType(); +/*! + * \brief Infer the type of all functions in a module. + * + * This pass should be used when typechecking modules + * with mutually recursive functions. + * + * \return The pass. + */ +TVM_DLL Pass InferTypeAll(); + /*! * \brief Search and eliminate common subexpression. For example, if there are * two expressions evaluated to an identical value, a single variable is created @@ -493,6 +503,15 @@ TVM_DLL Function UnCPS(const Function& f); */ TVM_DLL Expr DeDup(const Expr& e); +/*! + * \brief Deduplicate the bound type variables in the type. + * + * \param e the type (does not have to be typechecked). + * + * \return the deduplicated type. + */ +TVM_DLL Type DeDupType(const Type& e); + } // namespace relay } // namespace tvm diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 8d75d8e8ee21..ca7f5a5b1120 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -87,6 +87,10 @@ def _add(self, var, val, update=False): var = _ty.GlobalTypeVar(var) _ffi_api.Module_AddDef(self, var, val, update) + def add_unchecked(self, var, val): + assert isinstance(val, _expr.RelayExpr) + _ffi_api.Module_AddUnchecked(self, var, val) + def __getitem__(self, var): """Lookup a global definition by name or by variable. diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cc92141b73db..2d1ea829a92b 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -98,6 +98,8 @@ def InferType(): """ return _ffi_api.InferType() +def InferTypeAll(): + return _ffi_api.InferTypeAll() def FoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. 
This pass will diff --git a/src/ir/module.cc b/src/ir/module.cc index bcab39aabf32..55b1beae3b96 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -403,6 +403,9 @@ TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); +TVM_REGISTER_GLOBAL("ir.Module_AddUnchecked") + .set_body_method(&IRModuleNode::AddUnchecked); + TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method(&IRModuleNode::GetGlobalVar); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index dcd8de075854..37d8e45bf279 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -65,6 +65,9 @@ class TypeSolver { public: TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter); ~TypeSolver(); + + void SetCurrentFunc(const GlobalVar& current_func) { this->current_func = current_func; } + /*! * \brief Add a type constraint to the solver. * \param constraint The constraint to be added. diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index d90e5c584df3..6bc7873cfab8 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -99,6 +99,33 @@ Expr DeDup(const Expr& e) { return ret; } +// dedup bound type variables in type +// - types do not have to be already typechecked +Type DeDupType(const Type& e) { + class DeDupTypeMutator : public TypeMutator, public ExprMutator, public PatternMutator { + public: + TypeVar Fresh(const TypeVar& tv) { + TypeVar ret = TypeVar(tv->name_hint, tv->kind); + type_rename_[tv] = ret; + return ret; + } + + Type VisitType(const Type& t) final { return t.defined() ? 
TypeMutator::VisitType(t) : t; } + + Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } + + Type VisitType_(const TypeVarNode* op) final { + TypeVar v = GetRef(op); + return type_rename_.count(v) != 0 ? type_rename_.at(v) : Fresh(v); + } + + private: + std::unordered_map type_rename_; + }; + + Type ret = DeDupTypeMutator().VisitType(e); + return ret; +} TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup); } // namespace relay diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 7182f0e96f0f..3fc871435a6a 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -109,6 +109,44 @@ class TypeInferencer : private ExprFunctor, // inference the type of expr. Expr Infer(Expr expr); + void SetCurrentFunc(GlobalVar current_func) { + this->current_func_ = current_func; + this->solver_.SetCurrentFunc(current_func); + } + + void Solve(); + Expr ResolveType(Expr expr); + + // Lazily get type for expr + // expression, we will populate it now, and return the result. + Type GetType(const Expr& expr) { + auto it = type_map_.find(expr); + if (it != type_map_.end() && it->second.checked_type.defined()) { + if (expr.as() != nullptr) { + // if we don't dedup GlobalVarNode, two functions that use the same GlobalVar + // may resolve to the same type incorrectly + return DeDupType(it->second.checked_type); + } + return it->second.checked_type; + } + Type ret = this->VisitExpr(expr); + CHECK(ret.defined()); + KindCheck(ret, mod_); + ResolvedTypeInfo& rti = type_map_[expr]; + rti.checked_type = ret; + return ret; + } + + // Lazily get type for GlobalVar + // expression, we will populate it now, and return the result. 
+ Type GetTypeGlobalVar(const GlobalVar& expr) { + // we have to visit function + // or else it may not be type-checked + auto f = Downcast(mod_->Lookup(expr)); + Type ret = GetType(f); + return GetType(expr); + } + private: // type resolver that maps back to type class Resolver; @@ -143,21 +181,6 @@ class TypeInferencer : private ExprFunctor, } } - // Lazily get type for expr - // expression, we will populate it now, and return the result. - Type GetType(const Expr& expr) { - auto it = type_map_.find(expr); - if (it != type_map_.end() && it->second.checked_type.defined()) { - return it->second.checked_type; - } - Type ret = this->VisitExpr(expr); - CHECK(ret.defined()); - KindCheck(ret, mod_); - ResolvedTypeInfo& rti = type_map_[expr]; - rti.checked_type = ret; - return ret; - } - void ReportFatalError(const ObjectRef& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); @@ -303,7 +326,6 @@ class TypeInferencer : private ExprFunctor, this->ReportFatalError(match, ss); } } - return rtype; } @@ -543,14 +565,6 @@ class TypeInferencer : private ExprFunctor, } return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); } - - void Solve() { - solver_.Solve(); - - if (err_reporter.AnyErrors()) { - err_reporter.RenderErrors(mod_); - } - } }; class TypeInferencer::Resolver : public ExprMutator, PatternMutator { @@ -698,6 +712,20 @@ Expr TypeInferencer::Infer(Expr expr) { return resolved_expr; } +void TypeInferencer::Solve() { + solver_.Solve(); + + if (err_reporter.AnyErrors()) { + err_reporter.RenderErrors(mod_); + } +} + +Expr TypeInferencer::ResolveType(Expr expr) { + auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); + CHECK(WellFormed(resolved_expr)); + return resolved_expr; +} + struct AllCheckTypePopulated : ExprVisitor { void VisitExpr(const Expr& e) { if (e.as()) { @@ -742,6 +770,43 @@ Function InferType(const Function& func, const IRModule& mod, const 
GlobalVar& v return Downcast(func_ret); } +IRModule InferTypeAll(const IRModule& mod) { + CHECK(mod.defined()) << "internal error: module must be set for type inference"; + const auto& globalvars = mod->GetGlobalVars(); + + // first pass: fill in type for all functions + for (const auto& var : globalvars) { + relay::Function func = Downcast(mod->Lookup(var)); + func->checked_type_ = func->func_type_annotation(); + mod->AddUnchecked(var, func); + } + + // use dummy var, will be updated as we fill constraints, + // solve for each GlobalVar + TypeInferencer ti = TypeInferencer(mod, GlobalVar("dummy")); + + // second pass, fill in constraints + for (const auto& var : globalvars) { + ti.SetCurrentFunc(var); + ti.GetTypeGlobalVar(var); + } + + // solve constraints + ti.Solve(); + + // third pass, resolve types + for (const auto& var : globalvars) { + ti.SetCurrentFunc(var); + + relay::Function func = Downcast(mod->Lookup(var)); + Expr func_ret = ti.ResolveType(func); + // add function back to module + mod->AddUnchecked(var, Downcast(func_ret)); + } + + return mod; +} + namespace transform { Pass InferType() { @@ -750,8 +815,18 @@ Pass InferType() { return CreateFunctionPass(pass_func, 0, "InferType", {}); } +Pass InferTypeAll() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relay::InferTypeAll(m); }; + return CreateModulePass(pass_func, 0, "InferTypeAll", {}); +} + TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); }); +TVM_REGISTER_GLOBAL("relay._transform.InferTypeAll").set_body_typed([]() { + return InferTypeAll(); +}); + } // namespace transform } // namespace relay diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index cc4748c92b00..73dcc2cf9621 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -20,9 +20,12 @@ import pytest import tvm from tvm import te +import numpy as np +from tvm.relay.prelude 
import Prelude from tvm import relay from tvm.relay import op, transform, analysis from tvm.relay import Any +from tvm.relay.testing import add_nat_definitions def run_infer_type(expr, mod=None): if not mod: @@ -362,6 +365,148 @@ def test_let_polymorphism(): int32 = relay.TensorType((), "int32") tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) +def test_mutual_recursion(): + # f(x) = if x > 0 then g(x - 1) else 0 + # g(y) = if y > 0 then f(y - 1) else 0 + tensortype = relay.TensorType((), 'float32') + # we need to annotate with tensortype + # because binary op does add relations between operands + x = relay.Var("x", tensortype) + y = relay.Var("y", tensortype) + + zero = relay.Constant(tvm.nd.array(np.array(0, dtype='float32'))) + one = relay.Constant(tvm.nd.array(np.array(1, dtype='float32'))) + + f_gv = relay.GlobalVar('f') + g_gv = relay.GlobalVar('g') + + def body(var, call_func): + subtract_one = relay.op.subtract(var, one) + cond = relay.If(relay.op.greater(var, zero), + relay.Call(call_func, [subtract_one]), + zero) + func = relay.Function([var], cond) + return func + + f = body(x, g_gv) + g = body(y, f_gv) + + mod = tvm.IRModule() + # p = Prelude(mod) + mod.add_unchecked(f_gv, f) + mod.add_unchecked(g_gv, g) + mod = transform.InferTypeAll()(mod) + + expected = relay.FuncType([tensortype], tensortype) + tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) + tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) + +def test_mutual_recursion_adt(): + # f[A](x: A) = match x { + # Cons(a, Nil) => a + # Cons(_, b) => g(b) + # } + # g[B](y: B) = match y { + # Cons(a, Nil) => a + # Cons(_, b) => f(b) + # } + p = Prelude() + l = p.l + + A = relay.TypeVar("A") + B = relay.TypeVar("B") + + x = relay.Var("x") + y = relay.Var("y") + + f_gv = relay.GlobalVar('f') + g_gv = relay.GlobalVar('g') + + def body(var, call_func, type_param): + a = relay.Var("a", type_param) + b = relay.Var("b") + body = 
relay.Match( + var, + [ + relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternVar(a), relay.PatternConstructor(p.nil)]), a), + relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternVar(b)]), relay.Call(call_func, [b])) + ], + complete=False + ) + func = relay.Function([var], body, type_params=[type_param]) + return func + + f = body(x, g_gv, A) + g = body(y, f_gv, B) + + mod = p.mod + mod.add_unchecked(f_gv, f) + mod.add_unchecked(g_gv, g) + mod = transform.InferTypeAll()(mod) + + tv = relay.TypeVar("test") + expected = relay.FuncType([l(tv)], tv, [tv]) + tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) + tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) + +def test_mutual_recursion_peano(): + # even and odd function for peano function + # even(x) = match x { + # z => true + # s(a: nat) => odd(a) + # } + # odd(x) = match x { + # z => false + # s(a: nat) => even(a) + # } + p = Prelude() + add_nat_definitions(p) + z = p.z + s = p.s + + even_gv = relay.GlobalVar('even') + odd_gv = relay.GlobalVar('odd') + + def create_even_func(odd_f): + x = relay.Var("x") + a = relay.Var("a") + true = tvm.nd.array(np.array(True)) + + body = relay.Match( + x, + [ + relay.Clause(relay.PatternConstructor(z), relay.Constant(true)), + relay.Clause(relay.PatternConstructor(s, [relay.PatternVar(a)]), relay.Call(odd_f, [a])) + ] + ) + func = relay.Function([x], body) + return func + def create_odd_func(even_f): + x = relay.Var("x") + a = relay.Var("a") + false = tvm.nd.array(np.array(False)) + + body = relay.Match( + x, + [ + relay.Clause(relay.PatternConstructor(z), relay.Constant(false)), + relay.Clause(relay.PatternConstructor(s, [relay.PatternVar(a)]), relay.Call(even_f, [a])) + ] + ) + func = relay.Function([x], body) + return func + + even_func = create_even_func(odd_gv) + odd_func = create_odd_func(even_gv) + + mod = p.mod + mod.add_unchecked(even_gv, even_func) + mod.add_unchecked(odd_gv, odd_func) + mod = 
transform.InferTypeAll()(mod) + + expected = relay.FuncType([p.nat()], relay.TensorType((), dtype='bool')) + tvm.ir.assert_structural_equal(mod[even_gv].checked_type, expected) + tvm.ir.assert_structural_equal(mod[odd_gv].checked_type, expected) def test_if(): choice_t = relay.FuncType([], relay.scalar_type('bool'))