From de2d5aa9a46aaa86ec618374265191c35b3b6665 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 15 Jan 2019 10:30:52 -0800 Subject: [PATCH 1/2] Revert "fix handling a tuple node in op fusion (#2433)" This reverts commit 749cb2156520da84d77673e48002d1fac6f0a84e. From 00e5f638bd86700855c3240371f8046f6abab086 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 15 Jan 2019 10:30:52 -0800 Subject: [PATCH 2/2] Revert "[Relay] Expand type unification and other utilities (#2189)" This reverts commit e0a20ad4a5b6f007f7182f12958e061c873a6396. --- include/tvm/relay/pass.h | 68 ----- python/tvm/relay/ir_pass.py | 68 +---- src/relay/pass/type_infer.cc | 137 ++++----- src/relay/pass/type_solver.cc | 343 +++------------------- src/relay/pass/type_solver.h | 31 +- src/relay/pass/util.cc | 214 +++----------- tests/cpp/relay_pass_type_infer_test.cc | 16 +- tests/python/relay/test_pass_free_vars.py | 41 +++ tests/python/relay/test_pass_vars.py | 144 --------- tests/python/relay/test_type_infer.py | 81 ++--- tests/python/relay/test_type_solver.py | 164 ----------- 11 files changed, 250 insertions(+), 1057 deletions(-) create mode 100644 tests/python/relay/test_pass_free_vars.py delete mode 100644 tests/python/relay/test_pass_vars.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 566d69cc6b0b..1897809f48b8 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -108,17 +108,6 @@ bool AlphaEqual(const Type& t1, const Type& t2); */ bool WellFormed(const Expr& expr); -/*! \brief Get all bound variables from expression expr. - * - * Bound variables are all variables that are declared in the expr. - * They only have meaning inside that expr, and can only be used in it. - * - * \param expr the expression. - * - * \return List of bound vars, in the PostDFS order in the expression. - */ -tvm::Array BoundVars(const Expr& expr); - /*! \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a @@ -130,14 +119,6 @@ tvm::Array BoundVars(const Expr& expr); */ tvm::Array FreeVars(const Expr& expr); -/*! \brief Get all variables from expression expr. - * - * \param expr the expression. - * - * \return List of all vars, in the PostDFS order in the expression. - */ -tvm::Array AllVars(const Expr& expr); - /*! \brief Get free TypeVars from expression expr. * * Free type parameters are type parameters that are not bound by a function @@ -149,55 +130,6 @@ tvm::Array AllVars(const Expr& expr); */ tvm::Array FreeTypeVars(const Expr& expr); -/*! \brief Get free TypeVars from type t. - * - * Free type parameters are type parameters that are not bound by a function - * type in the context. - * - * \param t the type. - * - * \return List of free type vars, in the PostDFS order visited by type. - */ -tvm::Array FreeTypeVars(const Type& t); - -/*! \brief Get all bound type variables from expression expr. - * - * Bound variables are all type variables that are declared in the expr. - * They only have meaning inside that expr, and can only be used in it. - * - * \param expr the expression. - * - * \return List of bound type vars, in the PostDFS order in the expression. - */ -tvm::Array BoundTypeVars(const Expr& expr); - -/*! \brief Get all bound type variables from type t. - * - * Bound variables are all type variables that are declared in the type. - * They only have meaning inside that type, and can only be used in it. - * - * \param t the type - * - * \return List of bound type vars, in the PostDFS order visited by type. - */ -tvm::Array BoundTypeVars(const Type& t); - -/*! \brief Get all type variables in expression expr. - * - * \param expr the expression. - * - * \return List of type vars, in the PostDFS order in the expression. - */ -tvm::Array AllTypeVars(const Expr& expr); - -/*! \brief Get all type variables in type t. - * - * \param t the type. - * - * \return List of type vars, in the PostDFS order visited by type. - */ -tvm::Array AllTypeVars(const Type& t); - /*! \brief Remove expressions which does not effect the program result. * * It will remove let bindings which are not referenced, and branches that will diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index d5d5e9261fc7..1bec7ccd72d5 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -158,38 +158,6 @@ def free_vars(expr): return _ir_pass.free_vars(expr) -def bound_vars(expr): - """Get bound vars from expression expr in post-DFS order. - - Parameters - ---------- - expr: tvm.relay.Expr - The input expression - - Returns - ------- - free : List[tvm.relay.Var] - The list of bound variables in post-DFS order. - """ - return _ir_pass.bound_vars(expr) - - -def all_vars(expr): - """Get all vars from expression expr in post-DFS order. - - Parameters - ---------- - expr: tvm.relay.Expr - The input expression - - Returns - ------- - free : List[tvm.relay.Var] - The list of all variables in post-DFS order. - """ - return _ir_pass.all_vars(expr) - - def free_type_vars(expr): """Get free type variables from expression/type e @@ -200,44 +168,12 @@ def free_type_vars(expr): Returns ------- - free : List[tvm.relay.TypeVar] - The list of free type variables in post-DFS order + free : List[tvm.relay.TypeParam] + The list of free type variables """ return _ir_pass.free_type_vars(expr) -def bound_type_vars(expr): - """Get bound type variables from expression/type e - - Parameters - ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] - The input expression/type - - Returns - ------- - free : List[tvm.relay.TypeVar] - The list of bound type variables in post-DFS order - """ - return _ir_pass.bound_type_vars(expr) - - -def all_type_vars(expr): - """Get all type variables from expression/type e - - Parameters - ---------- - expr: Union[tvm.relay.Expr,tvm.relay.Type] - The input expression/type - - Returns - ------- - free : List[tvm.relay.TypeVar] - The list of all type variables in post-DFS order - """ - return _ir_pass.all_type_vars(expr) - - def simplify_inference(expr): """ Simplify the data-flow graph for inference phase. diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index af4cc6607a44..ee1b5ab10148 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -56,11 +56,31 @@ bool TupleGetItemRel(const Array& types, return true; } +bool MakeTupleRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(static_cast(num_inputs + 1), types.size()); + for (int i = 0; i < num_inputs; ++i) { + if (types[i].as()) return false; + } + Array fields; + for (int i = 0; i < num_inputs; ++i) { + fields.push_back(types[i]); + } + reporter->Assign(types[num_inputs], TupleTypeNode::make(fields)); + return true; +} + TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") .set_body_typed&, int, const Attrs&, const TypeReporter&)>( TupleGetItemRel); +TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple") +.set_body_typed&, int, const Attrs&, const TypeReporter&)>( + MakeTupleRel); + struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) : checked_type(checked_type), type_args(type_args) {} @@ -100,10 +120,6 @@ class TypeInferencer : private ExprFunctor { // type inferencer will populate it up std::unordered_map type_map_; - // used to ensure we don't have free type vars hanging around - // (a temporary measure until we have proper generalization implemented) - Map instantiation_map_; - // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -124,32 +140,6 @@ class TypeInferencer : private ExprFunctor { return Type(); } } - - // Substitutes every type var in t with a corresponding incomplete type. - // This is a temporary measure to ensure type vars behave until - // generalization is properly implemented. - Type Instantiate(const Type &t) { - if (!t.defined()) { - return t; - } - auto* ft = t.as(); - if (ft == nullptr) { - return Bind(t, instantiation_map_); - } - - for (auto type_param : ft->type_params) { - instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType)); - } - - Type ret_type = ft->ret_type; - if (!ret_type.defined()) { - ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - } - - auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); - return Bind(strip_tvs, instantiation_map_); - } - // Lazily get type for expr // will call visit to deduce it if it is not in the type_map_ Type GetType(const Expr &expr) { @@ -157,7 +147,7 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = Instantiate(this->VisitExpr(expr)); + Type ret = this->VisitExpr(expr); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; @@ -185,11 +175,19 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const TupleNode* op) final { + if (!make_tuple_rel_.defined()) { + make_tuple_rel_ = TypeRelationFn( + EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_); + } Array types; for (Expr field : op->fields) { types.push_back(GetType(field)); } - return TupleTypeNode::make(types); + Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + types.push_back(rtype); + solver_.AddConstraint(TypeRelationNode::make( + make_tuple_rel_, types, op->fields.size(), Attrs())); + return rtype; } Type VisitExpr_(const TupleGetItemNode* op) final { @@ -211,17 +209,11 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const LetNode* op) final { - // if the definition is a function literal, permit recursion - bool is_functional_literal = op->value.as() != nullptr; - if (is_functional_literal) { - type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - } - Type vtype = GetType(op->value); if (op->var->type_annotation.defined()) { vtype = Unify(vtype, op->var->type_annotation, op->span); } - CHECK(is_functional_literal || !type_map_.count(op->var)); + CHECK(!type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program type_map_[op->var].checked_type = vtype; return GetType(op->body); @@ -260,14 +252,16 @@ class TypeInferencer : private ExprFunctor { return rtype; } - // substitute the type args in the function type - FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array& ty_args) { + // instantiate the function type with fresh + FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args) { tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. - for (size_t i = 0; i < fn_ty->type_params.size(); i++) { - subst_map.Set(fn_ty->type_params[i], ty_args[i]); + for (auto ty_param : fn_ty->type_params) { + IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); + subst_map.Set(ty_param, fresh); + ty_args->push_back(fresh); } Type ret_type = fn_ty->ret_type; @@ -302,32 +296,13 @@ class TypeInferencer : private ExprFunctor { Type GeneralCall(const CallNode* call, Array arg_types) { Type ftype = GetType(call->op); auto* fn_ty_node = ftype.as(); - auto* inc_ty_node = ftype.as(); - - CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr) - << "only expressions with function types can be called, found " - << ftype << " at " << call->span; - - // incomplete type => it must be a function taking the arg types - // with an unknown return type - if (inc_ty_node != nullptr) { - Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); - Type unified = this->Unify(ftype, func_type, call->span); - fn_ty_node = unified.as(); - } - Array type_args = call->type_args; - if (type_args.size() == 0) { - for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { - type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); - } - } - CHECK(type_args.size() == fn_ty_node->type_params.size()) - << "Incorrect number of type args in " << call->span << ": " - << "Expected " << fn_ty_node->type_params.size() - << "but got " << type_args.size(); - FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); + CHECK(fn_ty_node != nullptr) + << "only expressions with function types can be called, found " + << ftype << " at " << call->span; + + Array type_args; + FuncType fn_ty = Instantiate(fn_ty_node, &type_args); AddTypeArgs(GetRef(call), type_args); @@ -378,17 +353,26 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const FunctionNode* f) final { - solver_.Solve(); - Array arg_types; for (auto param : f->params) { - arg_types.push_back(GetType(param)); + GetType(param); } Type rtype = GetType(f->body); - if (f->ret_type.defined()) { - rtype = this->Unify(f->ret_type, rtype, f->span); + // Run solver using the currently known information + solver_.Solve(); + // Trying to resolve + Array arg_types; + for (size_t i = 0; i < f->params.size(); ++i) { + Type atype = solver_.Resolve(GetType(f->params[i])); + CHECK(atype.as() == nullptr) + << "Cannot resolve type of " << i + << "-th parameter of function at" << f->span; + arg_types.push_back(atype); } - auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); - return solver_.Resolve(ret); + rtype = solver_.Resolve(rtype); + CHECK(rtype.as() == nullptr) + << "Cannot resolve return type of function at" << f->span; + // do not support constraint lifting for now. + return FuncTypeNode::make(arg_types, rtype, f->type_params, {}); } }; @@ -396,7 +380,7 @@ class TypeInferencer::Resolver : public ExprMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { + : tmap_(tmap), solver_(solver) { } Expr VisitExpr_(const VarNode* op) final { @@ -541,7 +525,6 @@ Expr TypeInferencer::Infer(Expr expr) { GetType(expr); // Step 1: Solve the constraints. solver_.Solve(); - // Step 2: Attach resolved types to checked_type field. auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); CHECK(WellFormed(resolved_expr)); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index caea3755b8f9..e1efcbbdd0b9 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -5,7 +5,6 @@ */ #include #include "type_solver.h" -#include "../ir/type_functor.h" namespace tvm { namespace relay { @@ -39,298 +38,9 @@ class TypeSolver::Reporter : public TypeReporterNode { TypeSolver* solver_; }; -class TypeSolver::OccursChecker : public TypeVisitor { - public: - explicit OccursChecker(TypeSolver* solver, TypeNode* var) - : solver_(solver), var_(var), found_(false) {} - - bool Check(const Type& t) { - VisitType(t); - return found_; - } - - void VisitType_(const IncompleteTypeNode* op) override { - IncompleteType t = GetRef(op); - TypeNode* node = solver_->GetTypeNode(t); - found_ = found_ || (var_->FindRoot() == node->FindRoot()); - } - - private: - TypeSolver* solver_; - TypeNode* var_; - bool found_; -}; - -class TypeSolver::Unifier : public TypeFunctor { - public: - explicit Unifier(TypeSolver* solver) : solver_(solver) {} - - Type Unify(const Type& src, const Type& dst) { - // Known limitation - // - handle shape pattern matching - TypeNode* lhs = solver_->GetTypeNode(dst); - TypeNode* rhs = solver_->GetTypeNode(src); - - // do occur check so we don't create self-referencing structure - if (lhs->FindRoot() == rhs->FindRoot()) { - return lhs->resolved_type; - } - if (lhs->resolved_type.as()) { - CHECK(!CheckOccurs(lhs, rhs->resolved_type)) - << "Incomplete type " << lhs->resolved_type << " occurs in " - << rhs->resolved_type << ", cannot unify"; - solver_->MergeFromTo(lhs, rhs); - return rhs->resolved_type; - } else if (rhs->resolved_type.as()) { - CHECK(!CheckOccurs(rhs, lhs->resolved_type)) - << "Incomplete type " << rhs->resolved_type << " occurs in " - << lhs->resolved_type << ", cannot unify"; - solver_->MergeFromTo(rhs, lhs); - return lhs->resolved_type; - } else { - Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); - CHECK(resolved.defined()) - << "Unable to unify parent types: " - << lhs->resolved_type << " and " << rhs->resolved_type; - TypeNode* top = solver_->GetTypeNode(resolved); - solver_->MergeFromTo(lhs, top); - solver_->MergeFromTo(rhs, top); - return resolved; - } - } - - // Checks whether lhs (taken to be a type var) occurs in t, meaning - // there is a recursive equality constraint, which should be rejected. - // N.b.: A tautology like ?a = ?a is okay and should be checked for - // *before* calling this method - bool CheckOccurs(TypeNode* lhs, const Type& t) { - OccursChecker rc(solver_, lhs); - return rc.Check(t); - } - - // default: unify only if alpha-equal - Type VisitTypeDefault_(const Node* op, const Type& tn) override { - NodeRef nr = GetRef(op); - Type t1 = GetRef(nr.as_derived()); - if (!AlphaEqual(t1, tn)) { - return Type(nullptr); - } - return t1; - } - - Type VisitType_(const TupleTypeNode* op, const Type& tn) override { - const auto* ttn = tn.as(); - if (!ttn || op->fields.size() != ttn->fields.size()) { - return Type(nullptr); - } - - TupleType tt1 = GetRef(op); - TupleType tt2 = GetRef(ttn); - - std::vector new_fields; - for (size_t i = 0; i < tt1->fields.size(); i++) { - Type field = Unify(tt1->fields[i], tt2->fields[i]); - new_fields.push_back(field); - } - return TupleTypeNode::make(new_fields); - } - - Type VisitType_(const FuncTypeNode* op, const Type& tn) override { - const auto* ftn = tn.as(); - if (!ftn - || op->arg_types.size() != ftn->arg_types.size() - || op->type_params.size() != ftn->type_params.size() - || op->type_constraints.size() != ftn->type_constraints.size()) { - return Type(nullptr); - } - - // remap type vars so they match - Map subst_map; - for (size_t i = 0; i < op->type_params.size(); i++) { - subst_map.Set(ftn->type_params[i], op->type_params[i]); - } - - auto ft1 = GetRef(op); - auto ft2 = Downcast(Bind(GetRef(ftn), subst_map)); - - Type ret_type = Unify(ft1->ret_type, ft2->ret_type); - - std::vector arg_types; - for (size_t i = 0; i < ft1->arg_types.size(); i++) { - Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); - arg_types.push_back(arg_type); - } - - std::vector type_constraints; - for (size_t i = 0; i < ft1->type_constraints.size(); i++) { - Type unified_constraint = Unify(ft1->type_constraints[i], - ft2->type_constraints[i]); - const auto* tcn = unified_constraint.as(); - CHECK(tcn) << "Two type constraints unified into a non-constraint?" - << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; - type_constraints.push_back(GetRef(tcn)); - } - - return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); - } - - private: - TypeSolver* solver_; -}; - -class TypeSolver::Resolver : public TypeMutator { - public: - explicit Resolver(TypeSolver* solver) : solver_(solver) {} - - Type Resolve(const Type& t) { - if (!t.defined()) { - return t; - } - return VisitType(t); - } - - Type VisitType_(const IncompleteTypeNode* op) override { - auto* node = solver_->GetTypeNode(GetRef(op)); - return node->resolved_type; - } - - private: - TypeSolver* solver_; -}; - -// It ends up being more compact to simply have TypeFunctor { - public: - explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) - : solver_(solver), rels_(rels) {} - - // adds the relation node to t and all child types of t - void Propagate(const Type& t) { - VisitType(t); - } - - void UpdateRelSet(const Type& t) { - TypeNode* tnode = solver_->GetTypeNode(t); - for (auto* rel : *rels_) { - tnode->rel_set.insert(rel); - } - } - - void VisitTypeDefault_(const Node* op) override { - NodeRef nr = GetRef(op); - Type t = GetRef(nr.as_derived()); - UpdateRelSet(t); - } - - void VisitType_(const TupleTypeNode* op) override { - TupleType tt = GetRef(op); - UpdateRelSet(tt); - - for (const Type& t : tt->fields) { - Propagate(t); - } - } - - void VisitType_(const FuncTypeNode* op) override { - FuncType ft = GetRef(op); - UpdateRelSet(ft); - - Propagate(ft->ret_type); - for (auto arg_type : ft->arg_types) { - Propagate(arg_type); - } - - for (auto type_param : ft->type_params) { - Propagate(type_param); - } - - for (auto type_cs : ft->type_constraints) { - Propagate(type_cs); - } - } - - private: - TypeSolver* solver_; - const std::unordered_set* rels_; -}; - -// similarly, we use TypeFunctor so we can use -// the default visitor case to avoid more overrides -class TypeSolver::Merger : public TypeFunctor { - public: - explicit Merger(TypeSolver* solver) : solver_(solver) {} - - // Merges src node to dst, ensures *all* type relations of all - // child nodes of src are transferred to dst. - void Merge(TypeNode* src, TypeNode* dst) { - if (src == dst) return; - dst_ = dst; - VisitType(src->resolved_type); - // set parent at the end so later calls to GetTypeNode go back to src - src->parent = dst; - - // now propagate relations to child nodes, since change to - // a child node should update parent too - Propagator prop(solver_, &dst->rel_set); - prop.Propagate(dst->resolved_type); - } - - // Transfers any relations linked to t to the stored dst. - // Any unresolved relations are added back to the queue, since - // there is now new information - void TransferLinks(const Type& t) { - TypeNode* src = solver_->GetTypeNode(t); - if (src == dst_) return; - for (auto* rel : src->rel_set) { - // if the relation is not yet resolved, add to queue - if (!rel->resolved) { - solver_->AddToQueue(rel); - dst_->rel_set.insert(rel); - } - } - } - - void VisitTypeDefault_(const Node* op) override { - NodeRef nr = GetRef(op); - Type t = GetRef(nr.as_derived()); - TransferLinks(t); - } - - void VisitType_(const TupleTypeNode* ttn) override { - auto tup = GetRef(ttn); - TransferLinks(tup); - - for (auto field : tup->fields) { - VisitType(field); - } - } - - void VisitType_(const FuncTypeNode* ftn) override { - auto func = GetRef(ftn); - TransferLinks(func); - - VisitType(func->ret_type); - for (auto arg : func->arg_types) { - VisitType(arg); - } - for (auto param : func->type_params) { - VisitType(param); - } - for (auto constraint : func->type_constraints) { - VisitType(constraint); - } - } - - private: - TypeSolver* solver_; - TypeNode* dst_; -}; - // constructor TypeSolver::TypeSolver() - : reporter_(make_node(this)) { + : reporter_(make_node(this)) { } // destructor @@ -344,16 +54,31 @@ TypeSolver::~TypeSolver() { } } -// merge src type node to dst -void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { - Merger merger(this); - merger.Merge(src, dst); -} - // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - Unifier unifier(this); - return unifier.Unify(dst, src); + // Known limitation + // - handle composite types whose component can be unknown. + // - handle shape pattern matching + TypeNode* lhs = GetTypeNode(dst); + TypeNode* rhs = GetTypeNode(src); + + // do occur check so we don't create self-referencing structure + if (lhs->FindRoot() == rhs->FindRoot()) { + return lhs->resolved_type; + } + if (lhs->resolved_type.as()) { + MergeFromTo(lhs, rhs); + return rhs->resolved_type; + } else if (rhs->resolved_type.as()) { + MergeFromTo(rhs, lhs); + return lhs->resolved_type; + } else { + lhs->parent = rhs; + CHECK(AlphaEqual(lhs->resolved_type, rhs->resolved_type)) + << "Incompatible parent types in UF:" + << lhs->resolved_type << " and " << rhs->resolved_type; + return rhs->resolved_type; + } } // Add type constraint to the solver. @@ -371,9 +96,9 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - std::unordered_set singleton { rnode }; - Propagator prop(this, &singleton); - prop.Propagate(tnode->resolved_type); + LinkNode* rlink = arena_.make >(); + rlink->value = rnode; + tnode->rel_list.Push(rlink); } // add the relation to the working queue. this->AddToQueue(rnode); @@ -385,10 +110,12 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { // Resolve a type in the solver context. Type TypeSolver::Resolve(const Type& type) { - Resolver resolver(this); auto it = tmap_.find(type); - Type t = (it != tmap_.end()) ? it->second->FindRoot()->resolved_type : type; - return resolver.Resolve(t); + if (it != tmap_.end()) { + return it->second->FindRoot()->resolved_type; + } else { + return type; + } } bool TypeSolver::Solve() { @@ -401,7 +128,7 @@ bool TypeSolver::Solve() { // update the relation with given evidence. Array args; for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { - args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); + args.push_back(tlink->value->FindRoot()->resolved_type); CHECK_LE(args.size(), rel->args.size()); } // call the function @@ -434,8 +161,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc([solver](Type lhs, Type rhs) { - return solver->Unify(lhs, rhs); + return TypedPackedFunc([solver](Type lhs, Type rhs) { + solver->Unify(lhs, rhs); }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index b4635fdec331..2f311c9b9810 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -18,7 +18,6 @@ namespace relay { using common::LinkNode; using common::LinkedList; - /*! * \brief Interface of type solver used in type inference. * @@ -66,11 +65,6 @@ class TypeSolver { Type Unify(const Type& lhs, const Type& rhs); private: - class OccursChecker; - class Unifier; - class Resolver; - class Propagator; - class Merger; class Reporter; struct TypeNode; struct RelationNode; @@ -83,15 +77,15 @@ class TypeSolver { * that can unifies the same types to the name resolved_type. * * It also contains collection of links to related Relations, - * which is stored in rel_set. + * which is stored in rel_list. */ struct TypeNode { /*! \brief The final resolved type */ Type resolved_type; /*! \brief type node in the union find algorithm */ TypeNode* parent{nullptr}; - /*! \brief set of relations that is related to this type node */ - std::unordered_set rel_set; + /*! \brief list of relations that is related to this type node */ + LinkedList rel_list; /*! * \brief Find the root type node, perform path compression * \return The root type node. @@ -131,7 +125,7 @@ class TypeSolver { size_t num_resolved_rels_{0}; /*! \brief map from type node to types. */ std::unordered_map tmap_; - /*! \brief Internal queue to update the relation */ + /*! \breif Internal queue to update the relation */ std::queue update_queue_; /*! \brief allocator of all the internal node obhect*/ common::Arena arena_; @@ -169,7 +163,22 @@ class TypeSolver { * \param src The source operand * \param dst The dst operand. */ - void MergeFromTo(TypeNode* src, TypeNode* dst); + void MergeFromTo(TypeNode* src, TypeNode* dst) { + if (src == dst) return; + src->parent = dst; + // move the link to the to dst + for (auto* rlink = src->rel_list.head; rlink != nullptr;) { + // store next pointer first before rlink get moved + auto* next = rlink->next; + // if the relation is not yet resolved + // send the relation to the new + if (!rlink->value->resolved) { + this->AddToQueue(rlink->value); + dst->rel_list.Push(rlink); + } + rlink = next; + } + } }; } // namespace relay diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 403863c1d757..b99d975135be 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -12,211 +12,105 @@ namespace tvm { namespace relay { -template -struct InsertionSet { - std::unordered_set set; - std::vector data; - void Insert(const T& t) { - if (set.count(t) == 0) { - set.insert(t); - data.push_back(t); - } - } -}; - -class TypeVarTVisitor : public TypeVisitor { +// FreeTypeVar +class FreeTypeVarTVisitor : public TypeVisitor { public: - TypeVarTVisitor( - InsertionSet* type_vars, - InsertionSet* bound_type_vars) - : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } + FreeTypeVarTVisitor( + Array* free_vars, + std::unordered_set* bound_vars) + : free_vars_(free_vars), bound_vars_(bound_vars) { } void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); - type_vars_->Insert(var); + if (bound_vars_->count(var) == 0) { + free_vars_->push_back(var); + } } void VisitType_(const FuncTypeNode* f) final { for (auto type_param : f->type_params) { - type_vars_->Insert(type_param); - bound_type_vars_->Insert(type_param); + bound_vars_->insert(type_param); } TypeVisitor::VisitType_(f); } private: - InsertionSet* type_vars_; - InsertionSet* bound_type_vars_; + Array* free_vars_; + std::unordered_set* bound_vars_; }; -class TypeVarEVisitor : private ExprVisitor { +class FreeTypeVarEVisitor : private ExprVisitor { public: - Array CollectFree() { - Array ret; - for (const auto& v : type_vars_.data) { - if (bound_type_vars_.set.count(v) == 0) { - ret.push_back(v); - } - } - return ret; - } - - Array CollectBound() { - Array ret; - for (const auto& v : bound_type_vars_.data) { - ret.push_back(v); - } - return ret; - } - - Array CollectAll() { - Array ret; - for (const auto& v : type_vars_.data) { - ret.push_back(v); - } - return ret; - } - - Array Free(const Expr& expr) { - VisitExpr(expr); - return CollectFree(); - } - - Array Free(const Type& type) { - VisitType(type); - return CollectFree(); - } - - Array Bound(const Expr& expr) { - VisitExpr(expr); - return CollectBound(); - } - - Array Bound(const Type& type) { - VisitType(type); - return CollectBound(); - } - - Array All(const Expr& expr) { - VisitExpr(expr); - return CollectAll(); + Array Find(const Expr& expr) { + this->VisitExpr(expr); + return free_vars_; } - Array All(const Type& type) { - VisitType(type); - return CollectAll(); + Array Find(const Type& type) { + this->VisitType(type); + return free_vars_; } void VisitExpr_(const FunctionNode* f) final { for (const auto& tp : f->type_params) { - type_vars_.Insert(tp); - bound_type_vars_.Insert(tp); + bound_vars_.insert(tp); } ExprVisitor::VisitExpr_(f); } void VisitType(const Type& t) final { - TypeVarTVisitor(&type_vars_, &bound_type_vars_) + FreeTypeVarTVisitor(&free_vars_, &bound_vars_) .VisitType(t); } private: - InsertionSet type_vars_; - InsertionSet bound_type_vars_; + // The result list + Array free_vars_; + std::unordered_set bound_vars_; }; -class VarVisitor : protected ExprVisitor { +class FreeVarVisitor : protected ExprVisitor { public: - Array Free(const Expr& expr) { + Array Find(const Expr& expr) { this->VisitExpr(expr); - Array ret; - for (const auto& v : vars_.data) { - if (bound_vars_.set.count(v) == 0) { - ret.push_back(v); - } - } - return ret; - } - - Array Bound(const Expr& expr) { - this->VisitExpr(expr); - Array ret; - for (const auto& v : bound_vars_.data) { - ret.push_back(v); - } - return ret; - } - - Array All(const Expr& expr) { - this->VisitExpr(expr); - Array ret; - for (const auto& v : vars_.data) { - ret.push_back(v); - } - return ret; - } - - void MarkBounded(const Var& v) { - bound_vars_.Insert(v); - vars_.Insert(v); + return free_vars_; } void VisitExpr_(const VarNode* var) final { - vars_.Insert(GetRef(var)); + if (bound_vars_.count(var) == 0) { + free_vars_.push_back(GetRef(var)); + } } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { - MarkBounded(param); + bound_vars_.insert(param.operator->()); } VisitExpr(op->body); } void VisitExpr_(const LetNode* op) final { - MarkBounded(op->var); + bound_vars_.insert(op->var.operator->()); VisitExpr(op->value); VisitExpr(op->body); } private: - InsertionSet vars_; - InsertionSet bound_vars_; + // The result list + Array free_vars_; + std::unordered_set bound_vars_; }; tvm::Array FreeTypeVars(const Expr& expr) { - return TypeVarEVisitor().Free(expr); + return FreeTypeVarEVisitor().Find(expr); } tvm::Array FreeTypeVars(const Type& type) { - return TypeVarEVisitor().Free(type); -} - -tvm::Array BoundTypeVars(const Expr& expr) { - return TypeVarEVisitor().Bound(expr); -} - -tvm::Array BoundTypeVars(const Type& type) { - return TypeVarEVisitor().Bound(type); -} - -tvm::Array AllTypeVars(const Expr& expr) { - return TypeVarEVisitor().All(expr); -} - -tvm::Array AllTypeVars(const Type& type) { - return TypeVarEVisitor().All(type); + return FreeTypeVarEVisitor().Find(type); } tvm::Array FreeVars(const Expr& expr) { - return VarVisitor().Free(expr); -} - -tvm::Array BoundVars(const Expr& expr) { - return VarVisitor().Bound(expr); -} - -tvm::Array AllVars(const Expr& expr) { - return VarVisitor().All(expr); + return FreeVarVisitor().Find(expr); } TVM_REGISTER_API("relay._ir_pass.free_vars") @@ -224,46 +118,16 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") *ret = FreeVars(args[0]); }); -TVM_REGISTER_API("relay._ir_pass.bound_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = BoundVars(args[0]); - }); - -TVM_REGISTER_API("relay._ir_pass.all_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = AllVars(args[0]); - }); - TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; - if (x.as_derived()) { + if (x.as()) { *ret = FreeTypeVars(Downcast(x)); } else { *ret = FreeTypeVars(Downcast(x)); } }); -TVM_REGISTER_API("relay._ir_pass.bound_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; - if (x.as_derived()) { - *ret = BoundTypeVars(Downcast(x)); - } else { - *ret = BoundTypeVars(Downcast(x)); - } - }); - -TVM_REGISTER_API("relay._ir_pass.all_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef x = args[0]; - if (x.as_derived()) { - *ret = AllTypeVars(Downcast(x)); - } else { - *ret = AllTypeVars(Downcast(x)); - } - }); - /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 50aed4c57338..385bde974014 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -6,17 +6,13 @@ TEST(Relay, SelfReference) { using namespace tvm; - auto tensor_type = relay::TensorTypeNode::make({}, ::tvm::Bool()); - auto x = relay::VarNode::make("x", relay::Type()); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); - - auto y = relay::VarNode::make("y", tensor_type); - auto call = relay::CallNode::make(f, Array{ y }); - auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); + auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); + auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); + auto x = relay::VarNode::make("x", type_a); + auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, Array{}); + auto fx = relay::CallNode::make(f, Array{ x }); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); - - auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); - CHECK(AlphaEqual(type_fx->checked_type(), expected)); + CHECK_EQ(type_fx->checked_type(), type_a); } int main(int argc, char ** argv) { diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py new file mode 100644 index 000000000000..151dbe1412bc --- /dev/null +++ b/tests/python/relay/test_pass_free_vars.py @@ -0,0 +1,41 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import free_vars, free_type_vars + +def test_free_vars(): + ty = relay.TensorType([], "int32") + x = relay.Var("x", ty) + fvx = free_vars(x) + assert len(fvx) == 1 + assert fvx[0] == x + v = relay.Constant(tvm.nd.array(10)) + + let = relay.Let(x, v, x) + fvx = free_vars(let) + assert len(free_vars(let)) == 0 + f = relay.Function([x], x, ty) + assert len(free_vars(f)) == 0 + + +def test_tuple(): + t = relay.Var('t') + fv = free_vars(relay.Tuple([t, t])) + assert len(fv) == 1 + assert fv[0] == t + fv = free_vars(relay.TupleGetItem(t, 123)) + assert len(fv) == 1 + assert fv[0] == t + + +def test_free_type_vars(): + tp = relay.TypeVar("") + ty = relay.TupleType([tp, relay.TensorType([], "int32")]) + x = relay.Var("x", ty) + y = relay.Var("y") + let = relay.Let(x, y, x) + fvl = free_vars(let) + assert len(fvl) == 1 + assert fvl[0] == y + ftvl = free_type_vars(let) + assert len(ftvl) == 1 + assert ftvl[0] == tp diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py deleted file mode 100644 index c8d3d6d14992..000000000000 --- a/tests/python/relay/test_pass_vars.py +++ /dev/null @@ -1,144 +0,0 @@ -import tvm -from tvm import relay -from tvm.relay.ir_pass import (free_vars, free_type_vars, - bound_vars, bound_type_vars, - all_vars, all_type_vars) - -def assert_vars_match(actual, expected): - assert len(actual) == len(expected) - for i in range(len(actual)): - assert actual[i] == expected[i] - - -def test_free_vars(): - ty = relay.TensorType([], "int32") - x = relay.Var("x", ty) - fvx = free_vars(x) - assert len(fvx) == 1 - assert fvx[0] == x - v = relay.Constant(tvm.nd.array(10)) - - let = relay.Let(x, v, x) - fvx = free_vars(let) - assert len(free_vars(let)) == 0 - f = relay.Function([x], x, ty) - assert len(free_vars(f)) == 0 - - -def test_free_vars_tuple(): - t = relay.Var('t') - fv = free_vars(relay.Tuple([t, t])) - assert len(fv) == 1 - assert fv[0] == t - fv = free_vars(relay.TupleGetItem(t, 123)) - assert len(fv) == 1 - assert fv[0] == t - - -def test_free_type_vars(): - tp = relay.TypeVar("") - ty = relay.TupleType([tp, relay.TensorType([], "int32")]) - x = relay.Var("x", ty) - y = relay.Var("y") - let = relay.Let(x, y, x) - fvl = free_vars(let) - assert len(fvl) == 1 - assert fvl[0] == y - ftvl = free_type_vars(let) - assert len(ftvl) == 1 - assert ftvl[0] == tp - - -def test_bound_vars(): - x = relay.Var("x") - y = relay.Var("y") - z = relay.Var("z") - a = relay.Var("a") - - f1 = relay.Function([x, y, z], relay.Let(a, x, relay.Tuple([]))) - assert_vars_match(bound_vars(f1), [x, y, z, a]) - - tup = relay.Tuple([x, y, z, a]) - assert len(bound_vars(tup)) == 0 - - f2 = relay.Function([x, y], relay.Tuple([x, y, z, a])) - assert_vars_match(bound_vars(f2), [x, y]) - - -def test_bound_type_vars(): - a = relay.TypeVar("a") - b = relay.TypeVar("b") - c = relay.TypeVar("c") - - ft1 = relay.FuncType([a], b, [a, b]) - bound_ft1 = bound_type_vars(ft1) - assert_vars_match(bound_type_vars(ft1), [a, b]) - - ft2 = relay.FuncType([], c, [a]) - assert_vars_match(bound_type_vars(ft2), [a]) - - tup_ty = relay.TupleType([a, b, c]) - assert len(bound_type_vars(tup_ty)) == 0 - - f1 = relay.Function([], relay.Tuple([]), type_params=[a, b]) - assert_vars_match(bound_type_vars(f1), [a, b]) - - f2 = relay.Function([], relay.Tuple([]), c) - assert len(bound_type_vars(f2)) == 0 - - x = relay.Var("x", a) - let1 = relay.Let(x, relay.Tuple([]), x) - assert len(bound_type_vars(let1)) == 0 - - let2 = relay.Let(x, relay.Function([], relay.Tuple([]), type_params=[b, c]), x) - assert_vars_match(bound_type_vars(let2), [b, c]) - - -def test_all_vars(): - x = relay.Var("x") - y = relay.Var("y") - z = relay.Var("z") - - f1 = relay.Function([x, y], z) - assert_vars_match(all_vars(f1), [x, y, z]) - - f2 = relay.Function([x], relay.Let(y, relay.Tuple([]), z)) - assert_vars_match(all_vars(f2), [x, y, z]) - - f3 = relay.Function([x], relay.Tuple([y, z])) - assert_vars_match(all_vars(f3), [x, y, z]) - - tup = relay.Tuple([x, y, z]) - assert_vars_match(all_vars(tup), [x, y, z]) - - -def test_all_type_vars(): - a = relay.TypeVar("a") - b = relay.TypeVar("b") - c = relay.TypeVar("c") - - ft1 = relay.FuncType([b], c, [a]) - assert_vars_match(all_type_vars(ft1), [a, b, c]) - - ft2 = relay.FuncType([], relay.TupleType([a, b, c]), []) - assert_vars_match(all_type_vars(ft2), [a, b, c]) - - w = relay.Var("w") - x = relay.Var("x", a) - y = relay.Var("y", b) - z = relay.Var("z", c) - - f1 = relay.Function([x], y, b, [a]) - assert_vars_match(all_type_vars(f1), [a, b]) - - f2 = relay.Function([x], relay.Let(y, x, z)) - assert_vars_match(all_type_vars(f2), [a, b, c]) - - f3 = relay.Function([], relay.Tuple([x, y, z]), ret_type=relay.TupleType([a, b, c])) - assert_vars_match(all_type_vars(f3), [a, b, c]) - - f4 = relay.Function([w], relay.Tuple([]), type_params=[a, b, c]) - assert_vars_match(all_type_vars(f4), [a, b, c]) - - f5 = relay.Function([w], w) - assert len(all_type_vars(f5)) == 0 diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index ac4eb1b404db..06cb19639dcf 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -23,7 +23,7 @@ def test_monomorphic_let(): x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) xchecked = relay.ir_pass.infer_type(sb.get()) - assert xchecked.checked_type == relay.scalar_type("float64" ) + assert xchecked.checked_type == relay.scalar_type("float64") def test_single_op(): @@ -41,15 +41,14 @@ def test_add_broadcast_op(): return x + y; } """ - x = relay.var('x', shape=(10, 4)) - y = relay.var('y', shape=(5, 10, 1)) - z = x + y - func = relay.Function([x, y], z) - t1 = relay.TensorType((10, 4), 'float32') - t2 = relay.TensorType((5, 10, 1), 'float32') - t3 = relay.TensorType((5, 10, 4), 'float32') - expected_ty = relay.FuncType([t1, t2], t3) - assert_has_type(func, expected_ty) + pass + # x = relay.var('x', shape=(10, 4)) + # y = relay.var('y', shape=(5, 10, 1)) + # z = x + y + # func = relay.Function([x, y], z) + # ttype = relay.TensorType((5, 5, 5), 'float32') + # expected_ty = relay.FuncType([ttype, ttype], ttype) + # assert_has_type(func.to_func(), expected_ty) def test_dual_op(): @@ -111,17 +110,24 @@ def f(n: i32, data: f32) -> f32 { assert "%3 = @f(%1, %2)" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) +# This currently fails and should pass under the type system. +# +# This test is to illustrate problem with our weak form of +# unification. +# + def test_incomplete_call(): - tt = relay.scalar_type('int32') - x = relay.var('x', tt) + sb = ScopeBuilder() + x = relay.var('x', dtype='int32') f = relay.var('f') - func = relay.Function([x, f], relay.Call(f, [x]), tt) - - ft = relay.ir_pass.infer_type(func) - f_type = relay.FuncType([tt], tt) - assert ft.checked_type == relay.FuncType([tt, f_type], tt) + func = relay.Function([x, f], relay.Call(f, [x])) + try: + relay.ir_pass.infer_type(func) + assert False + except tvm.TVMError as e: + assert True def test_tuple(): tp = relay.TensorType((10,)) @@ -130,7 +136,6 @@ def test_tuple(): assert (relay.ir_pass.infer_type(res).checked_type == relay.TupleType([tp, tp])) - def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) @@ -156,26 +161,38 @@ def test_type_args(): assert sh2[1].value == 10 -def test_global_var_recursion(): +def test_self_reference(): + """ + Program: + def f(x) { + return x; + } + """ + a = relay.TypeVar("a") + x = relay.var("x", a) + sb = relay.ScopeBuilder() + + f = relay.Function([x], x) + fx = relay.Call(f, [x]) + assert relay.ir_pass.infer_type(x).checked_type == a + assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) + assert relay.ir_pass.infer_type(fx).checked_type == a + + +def test_global_var_cow_issue(): mod = relay.Module({}) gv = relay.GlobalVar("foo") x = relay.var('x', shape=[]) - tt = relay.scalar_type('float32') - - func = relay.Function([x], relay.Call(gv, [x]), tt) + func = relay.Function([x], relay.Call(gv, [x]), + relay.TensorType([], 'float32')) mod[gv] = func - ft = relay.ir_pass.infer_type(gv, mod) - assert mod[ft].checked_type == relay.FuncType([tt], tt) - def test_equal(): i = relay.var('i', shape=[], dtype='int32') eq = op.equal(i, relay.const(0, dtype='int32')) - func = relay.Function([i], eq) - ft = relay.ir_pass.infer_type(func) - - assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) + # This should fail .... + func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32')) if __name__ == "__main__": @@ -187,12 +204,8 @@ def test_equal(): test_decl() test_recursion() test_tuple() - test_generalized_tuple() test_incomplete_call() - test_generalized_call() - test_call_with_type_args() test_free_expr() test_type_args() test_self_reference() - test_global_var_recursion() - test_equal() + test_global_var_cow_issue() diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 1e2fed0af1f8..e8ff67756931 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -1,6 +1,5 @@ import tvm from tvm import relay -from nose.tools import raises def make_rel(name, args, num_inputs=None, attrs=None): @@ -49,170 +48,7 @@ def test_backward_solving(): assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32") -def test_unify_tuple(): - solver = make_solver() - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - t3 = relay.ty.TensorType((10, 20), "float32") - - tup1 = relay.ty.TupleType([t1, t2]) - tup2 = relay.ty.TupleType([t3, t3]) - - unified = solver.Unify(tup1, tup2) - assert unified == tup2 - - -def test_unify_functype(): - solver = make_solver() - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - t3 = relay.ty.IncompleteType() - - unit = relay.ty.TupleType([]) - tensor1 = relay.ty.TensorType((10, 20), "float32") - tensor2 = relay.ty.TensorType((10,), "float32") - - ft1 = relay.ty.FuncType([t1, t2], t3) - ft2 = relay.ty.FuncType([tensor1, tensor2], unit) - - unified = solver.Unify(ft1, ft2) - assert unified == ft2 - - -def test_recursive_unify(): - solver = make_solver() - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - t3 = relay.ty.IncompleteType() - - tensor1 = relay.ty.TensorType((10, 10, 20), "float32") - tensor2 = relay.ty.TensorType((10, 20), "float32") - tensor3 = relay.ty.TensorType((10,), "float32") - - tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2]) - tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2]) - - ft1 = relay.ty.FuncType([tup1, t3], t3) - ft2 = relay.ty.FuncType([tup2, tensor3], tensor3) - - unified = solver.Unify(ft1, ft2) - assert unified == ft2 - - -def test_unify_vars_under_tuples(): - solver = make_solver() - t1 = relay.ty.IncompleteType() - - tup1 = relay.ty.TupleType([t1, t1]) - unified = solver.Unify(tup1, tup1) - assert unified == tup1 - - t2 = relay.ty.IncompleteType() - tup2 = relay.ty.TupleType([t2, t2]) - - tup3 = relay.ty.TupleType([t1, t2]) - tup4 = relay.ty.TupleType([t2, t1]) - unified = solver.Unify(tup3, tup4) - assert (unified == tup1 or unified == tup2) - - -def test_binding_over_typevars(): - solver = make_solver() - - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - - a = relay.ty.TypeVar('a') - b = relay.ty.TypeVar('b') - c = relay.ty.TypeVar('c') - d = relay.ty.TypeVar('d') - - ft1 = relay.ty.FuncType([t1], t2, [c, d]) - ft2 = relay.ty.FuncType([a], b, [a, b]) - unified = solver.Unify(ft1, ft2) - assert (unified == solver.Resolve(ft1)) - - -def test_recursive_backward_solving(): - solver = make_solver() - - tensor1 = relay.ty.TensorType((10, 20), "float32") - tensor2 = relay.ty.TensorType((10, 1, 1), "float32") - tensor3 = relay.ty.TensorType((10,), "float32") - - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - t3 = relay.ty.IncompleteType() - - tup1 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) - tup2 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t3]) - solver.gen_type("Identity", [tup1], out=tup2) - - assert solver.Solve() - assert solver.Resolve(tup2) == tup1 - - -def test_backward_solving_after_child_update(): - solver = make_solver() - - tensor1 = relay.ty.TensorType((10, 20), "float32") - tensor2 = relay.ty.TensorType((10, 1, 1), "float32") - - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - t3 = relay.ty.IncompleteType() - - tup1 = relay.ty.TupleType([t1, t2]) - tup2 = relay.ty.TupleType([t1, t3]) - - tup_concrete = relay.ty.TupleType([tensor1, tensor2]) - - t4 = solver.gen_type("Identity", [tup1]) - t5 = solver.gen_type("Identity", [tup2]) - - solver.gen_type("Identity", [t4], out=t5) - assert solver.Solve() - assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2 - assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2 - assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2 - - # updating the variables *inside* tup1 and tup2 should update t4 and t5 - solver.gen_type("Identity", [t1], out=tensor1) - solver.gen_type("Identity", [t2], out=tensor2) - assert solver.Solve() - assert solver.Resolve(t4) == tup_concrete - assert solver.Resolve(t5) == tup_concrete - -@raises(tvm._ffi.base.TVMError) -def test_incompatible_tuple_unification(): - solver = make_solver() - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - - tensor1 = relay.ty.TensorType((1, 2, 3), "float32") - tensor2 = relay.ty.TensorType((2, 3), "float32") - tensor3 = relay.ty.TensorType((3,), "float32") - - tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t1]), t2]) - tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) - solver.Unify(tup1, tup2) - - -@raises(tvm._ffi.base.TVMError) -def test_bad_recursive_unification(): - solver = make_solver() - t1 = relay.ty.IncompleteType() - solver.Unify(t1, relay.ty.TupleType([t1, t1])) - if __name__ == "__main__": test_bcast() test_backward_solving() - test_unify_tuple() - test_unify_functype() - test_recursive_unify() - test_unify_vars_under_tuples() - test_recursive_backward_solving() - test_backward_solving_after_child_update() - test_incompatible_tuple_unification() - test_bad_recursive_unification()