From 862825a4e4ad02d6f82dbb6913f2ca657e1bf3c6 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 28 Nov 2018 18:49:01 -0800 Subject: [PATCH 01/55] Expand unification in type solver --- src/relay/pass/type_solver.cc | 148 ++++++++++++++++++++++++++++------ src/relay/pass/type_solver.h | 23 ++++-- 2 files changed, 141 insertions(+), 30 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index e1efcbbdd0b9..88c8be973e94 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -5,6 +5,7 @@ */ #include #include "type_solver.h" +#include "../ir/type_functor.h" namespace tvm { namespace relay { @@ -38,6 +39,128 @@ class TypeSolver::Reporter : public TypeReporterNode { TypeSolver* solver_; }; +class TypeSolver::Unifier : public TypeFunctor { + public: + explicit Unifier(TypeSolver* solver) : solver_(solver) {} + + Type Unify(const Type& src, const Type& dst) { + // Known limitation + // - handle shape pattern matching + TypeNode* lhs = solver_->GetTypeNode(dst); + TypeNode* rhs = solver_->GetTypeNode(src); + + // do occur check so we don't create self-referencing structure + if (lhs->FindRoot() == rhs->FindRoot()) { + return lhs->resolved_type; + } + if (lhs->resolved_type.as()) { + solver_->MergeFromTo(lhs, rhs); + return rhs->resolved_type; + } else if (rhs->resolved_type.as()) { + solver_->MergeFromTo(rhs, lhs); + return lhs->resolved_type; + } else { + Type resolved = + this->VisitType(lhs->resolved_type, rhs->resolved_type); + CHECK(resolved.defined()) + << "Unable to unify parent types: " + << lhs->resolved_type << " and " << rhs->resolved_type; + rhs->resolved_type = resolved; + solver_->MergeFromTo(lhs, rhs); + return resolved; + } + } + + // child type needs to be listed in parent's relations, even though + // the child is not an argument to the relations (still have to + // update the relations if the child changes) + void RegisterChildType(const Type& parent, const Type& child) { + TypeNode* parent_node = solver_->GetTypeNode(parent); + TypeNode* child_node = solver_->GetTypeNode(child); + solver_->TransferQueue(parent_node, child_node); + } + + // default: unify only if alpha-equal + Type VisitTypeDefault_(const Node* op, const Type& tn) override { + NodeRef nr = GetRef(op); + Type t1 = GetRef(nr.as_derived()); + if (!AlphaEqual(t1, tn)) { + return Type(nullptr); + } + return t1; + } + + Type VisitType_(const TupleTypeNode* op, const Type& tn) override { + const auto* ttn = tn.as(); + if (!ttn || op->fields.size() != ttn->fields.size()) { + return Type(nullptr); + } + + TupleType tt1 = GetRef(op); + TupleType tt2 = GetRef(ttn); + + std::vector new_fields; + for (size_t i = 0; i < tt1->fields.size(); i++) { + RegisterChildType(tt1, tt1->fields[i]); + RegisterChildType(tt2, tt1->fields[i]); + new_fields.push_back(Unify(tt1->fields[i], tt2->fields[i])); + } + return TupleTypeNode::make(new_fields); + } + + Type VisitType_(const FuncTypeNode* op, const Type& tn) override { + const auto* ftn = tn.as(); + if (!ftn + || op->arg_types.size() != ftn->arg_types.size() + || op->type_params.size() != ftn->type_params.size() + || op->type_constraints.size() != ftn->type_constraints.size()) { + return Type(nullptr); + } + + FuncType ft1 = GetRef(op); + FuncType ft2 = GetRef(ftn); + + RegisterChildType(ft1, ft1->ret_type); + RegisterChildType(ft2, ft2->ret_type); + Type ret_type = Unify(ft1->ret_type, ft2->ret_type); + + std::vector arg_types; + for (size_t i = 0; i < ft1->arg_types.size(); i++) { + RegisterChildType(ft1, ft1->arg_types[i]); + RegisterChildType(ft2, ft2->arg_types[i]); + arg_types.push_back(Unify(ft1->arg_types[i], ft2->arg_types[i])); + } + + std::vector type_params; + for (size_t i = 0; i < ft1->type_params.size(); i++) { + RegisterChildType(ft1, ft1->type_params[i]); + RegisterChildType(ft2, ft2->type_params[i]); + Type unified_var = Unify(ft1->type_params[i], ft2->type_params[i]); + const auto* tvn = unified_var.as(); + CHECK(tvn) << "Two type vars unified into a non type var? " + << ft1->type_params[i] << " and " << ft2->type_params[i]; + type_params.push_back(GetRef(tvn)); + } + + std::vector type_constraints; + for (size_t i = 0; i < ft1->type_constraints.size(); i++) { + RegisterChildType(ft1, ft1->type_constraints[i]); + RegisterChildType(ft2, ft2->type_constraints[i]); + Type unified_constraint = Unify(ft1->type_constraints[i], + ft2->type_constraints[i]); + const auto* tcn = unified_constraint.as(); + CHECK(tcn) << "Two type constraints unified into a non-constraint?" + << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; + type_constraints.push_back(GetRef(tcn)); + } + + return FuncTypeNode::make(arg_types, ret_type, type_params, type_constraints); + } + + private: + TypeSolver* solver_; +}; + // constructor TypeSolver::TypeSolver() : reporter_(make_node(this)) { @@ -56,29 +179,8 @@ TypeSolver::~TypeSolver() { // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - // Known limitation - // - handle composite types whose component can be unknown. - // - handle shape pattern matching - TypeNode* lhs = GetTypeNode(dst); - TypeNode* rhs = GetTypeNode(src); - - // do occur check so we don't create self-referencing structure - if (lhs->FindRoot() == rhs->FindRoot()) { - return lhs->resolved_type; - } - if (lhs->resolved_type.as()) { - MergeFromTo(lhs, rhs); - return rhs->resolved_type; - } else if (rhs->resolved_type.as()) { - MergeFromTo(rhs, lhs); - return lhs->resolved_type; - } else { - lhs->parent = rhs; - CHECK(AlphaEqual(lhs->resolved_type, rhs->resolved_type)) - << "Incompatible parent types in UF:" - << lhs->resolved_type << " and " << rhs->resolved_type; - return rhs->resolved_type; - } + Unifier unifier(this); + return unifier.Unify(dst, src); } // Add type constraint to the solver. diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 2f311c9b9810..4c3b9e17c835 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -65,6 +65,7 @@ class TypeSolver { Type Unify(const Type& lhs, const Type& rhs); private: + class Unifier; class Reporter; struct TypeNode; struct RelationNode; @@ -159,19 +160,16 @@ class TypeSolver { update_queue_.push(rel); } /*! - * \brief Merge rhs type node to lhs + * \brief Adds relations in relation queue of src to dst * \param src The source operand - * \param dst The dst operand. + * \param dst The dst operand */ - void MergeFromTo(TypeNode* src, TypeNode* dst) { - if (src == dst) return; - src->parent = dst; - // move the link to the to dst + void TransferQueue(TypeNode* src, TypeNode* dst) { for (auto* rlink = src->rel_list.head; rlink != nullptr;) { // store next pointer first before rlink get moved auto* next = rlink->next; // if the relation is not yet resolved - // send the relation to the new + // send the relation to the queue if (!rlink->value->resolved) { this->AddToQueue(rlink->value); dst->rel_list.Push(rlink); @@ -179,6 +177,17 @@ class TypeSolver { rlink = next; } } + /*! + * \brief Merge rhs type node to lhs + * \param src The source operand + * \param dst The dst operand. + */ + void MergeFromTo(TypeNode* src, TypeNode* dst) { + if (src == dst) return; + src->parent = dst; + // move the link to the to dst + TransferQueue(src, dst); + } }; } // namespace relay From c60e4c817a6368c938c3d40247455d93c53a3437 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 28 Nov 2018 22:22:10 -0800 Subject: [PATCH 02/55] Only register child type relations for incomplete child types --- src/relay/pass/type_solver.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 88c8be973e94..c2f1cdd44082 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -65,8 +65,6 @@ class TypeSolver::Unifier : public TypeFunctor { CHECK(resolved.defined()) << "Unable to unify parent types: " << lhs->resolved_type << " and " << rhs->resolved_type; - rhs->resolved_type = resolved; - solver_->MergeFromTo(lhs, rhs); return resolved; } } @@ -77,6 +75,11 @@ class TypeSolver::Unifier : public TypeFunctor { void RegisterChildType(const Type& parent, const Type& child) { TypeNode* parent_node = solver_->GetTypeNode(parent); TypeNode* child_node = solver_->GetTypeNode(child); + + // if child is already concrete type, nothing to do + if (!child_node->resolved_type.as()) { + return; + } solver_->TransferQueue(parent_node, child_node); } @@ -102,7 +105,7 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector new_fields; for (size_t i = 0; i < tt1->fields.size(); i++) { RegisterChildType(tt1, tt1->fields[i]); - RegisterChildType(tt2, tt1->fields[i]); + RegisterChildType(tt2, tt2->fields[i]); new_fields.push_back(Unify(tt1->fields[i], tt2->fields[i])); } return TupleTypeNode::make(new_fields); From 318d85fc7d1b35080e86602893c02b67352dfeda Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 28 Nov 2018 22:25:30 -0800 Subject: [PATCH 03/55] Removed redundant registrations --- src/relay/pass/type_solver.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index c2f1cdd44082..782c4d638582 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -136,8 +136,6 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector type_params; for (size_t i = 0; i < ft1->type_params.size(); i++) { - RegisterChildType(ft1, ft1->type_params[i]); - RegisterChildType(ft2, ft2->type_params[i]); Type unified_var = Unify(ft1->type_params[i], ft2->type_params[i]); const auto* tvn = unified_var.as(); CHECK(tvn) << "Two type vars unified into a non type var? " @@ -147,8 +145,6 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector type_constraints; for (size_t i = 0; i < ft1->type_constraints.size(); i++) { - RegisterChildType(ft1, ft1->type_constraints[i]); - RegisterChildType(ft2, ft2->type_constraints[i]); Type unified_constraint = Unify(ft1->type_constraints[i], ft2->type_constraints[i]); const auto* tcn = unified_constraint.as(); From 5155426a866484eaabf642d3e3a7c1c126651caa Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 29 Nov 2018 16:15:17 -0800 Subject: [PATCH 04/55] Be sure to copy linked list nodes for child types --- src/relay/pass/type_solver.cc | 37 +++++++++++++++++++++++------------ src/relay/pass/type_solver.h | 20 ++++++------------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 782c4d638582..8ec90c776011 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -65,6 +65,9 @@ class TypeSolver::Unifier : public TypeFunctor { CHECK(resolved.defined()) << "Unable to unify parent types: " << lhs->resolved_type << " and " << rhs->resolved_type; + TypeNode* top = solver_->GetTypeNode(resolved); + solver_->MergeFromTo(lhs, top); + solver_->MergeFromTo(rhs, top); return resolved; } } @@ -76,11 +79,19 @@ class TypeSolver::Unifier : public TypeFunctor { TypeNode* parent_node = solver_->GetTypeNode(parent); TypeNode* child_node = solver_->GetTypeNode(child); - // if child is already concrete type, nothing to do - if (!child_node->resolved_type.as()) { - return; + // allocate copies to avoid making a circular linked list + for (auto* rlink = parent_node->rel_list.head; rlink != nullptr;) { + auto* next = rlink->next; + auto* value = rlink->value; + if (!value->resolved) { + solver_->AddToQueue(value); + } + auto* copy = solver_->arena_.make >(); + copy->value = value; + child_node->rel_list.Push(copy); + + rlink = next; } - solver_->TransferQueue(parent_node, child_node); } // default: unify only if alpha-equal @@ -104,9 +115,10 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector new_fields; for (size_t i = 0; i < tt1->fields.size(); i++) { - RegisterChildType(tt1, tt1->fields[i]); - RegisterChildType(tt2, tt2->fields[i]); - new_fields.push_back(Unify(tt1->fields[i], tt2->fields[i])); + Type field = Unify(tt1->fields[i], tt2->fields[i]); + RegisterChildType(tt1, field); + RegisterChildType(tt2, field); + new_fields.push_back(field); } return TupleTypeNode::make(new_fields); } @@ -123,15 +135,16 @@ class TypeSolver::Unifier : public TypeFunctor { FuncType ft1 = GetRef(op); FuncType ft2 = GetRef(ftn); - RegisterChildType(ft1, ft1->ret_type); - RegisterChildType(ft2, ft2->ret_type); Type ret_type = Unify(ft1->ret_type, ft2->ret_type); + RegisterChildType(ft1, ret_type); + RegisterChildType(ft2, ret_type); std::vector arg_types; for (size_t i = 0; i < ft1->arg_types.size(); i++) { - RegisterChildType(ft1, ft1->arg_types[i]); - RegisterChildType(ft2, ft2->arg_types[i]); - arg_types.push_back(Unify(ft1->arg_types[i], ft2->arg_types[i])); + Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); + RegisterChildType(ft1, arg_type); + RegisterChildType(ft2, arg_type); + arg_types.push_back(arg_type); } std::vector type_params; diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 4c3b9e17c835..327910103cdc 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -160,11 +160,14 @@ class TypeSolver { update_queue_.push(rel); } /*! - * \brief Adds relations in relation queue of src to dst + * \brief Merge rhs type node to lhs * \param src The source operand - * \param dst The dst operand + * \param dst The dst operand. */ - void TransferQueue(TypeNode* src, TypeNode* dst) { + void MergeFromTo(TypeNode* src, TypeNode* dst) { + if (src == dst) return; + src->parent = dst; + // move the link to the to dst for (auto* rlink = src->rel_list.head; rlink != nullptr;) { // store next pointer first before rlink get moved auto* next = rlink->next; @@ -177,17 +180,6 @@ class TypeSolver { rlink = next; } } - /*! - * \brief Merge rhs type node to lhs - * \param src The source operand - * \param dst The dst operand. - */ - void MergeFromTo(TypeNode* src, TypeNode* dst) { - if (src == dst) return; - src->parent = dst; - // move the link to the to dst - TransferQueue(src, dst); - } }; } // namespace relay From 94c5d97673434a83de7df466d92aa4248b78399c Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 29 Nov 2018 16:45:59 -0800 Subject: [PATCH 05/55] Add unifier tests --- src/relay/pass/type_solver.cc | 4 +- tests/python/relay/test_type_solver.py | 70 ++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 8ec90c776011..37427152c3c4 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -275,8 +275,8 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") return solver->Solve(); }); } else if (name == "Unify") { - return TypedPackedFunc([solver](Type lhs, Type rhs) { - solver->Unify(lhs, rhs); + return TypedPackedFunc([solver](Type lhs, Type rhs) { + return solver->Unify(lhs, rhs); }); } else if (name == "Resolve") { return TypedPackedFunc([solver](Type t) { diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index e8ff67756931..5f520687b5e9 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -48,7 +48,77 @@ def test_backward_solving(): assert solver.Resolve(t3) == relay.ty.TensorType((10, 10, 20), "float32") +def test_unify_tuple(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.TensorType((10, 20), "float32") + + tup1 = relay.ty.TupleType([t1, t2]) + tup2 = relay.ty.TupleType([t3, t3]) + + unified = solver.Unify(tup1, tup2) + assert unified == tup2 + + +def test_unify_functype(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + unit = relay.ty.TupleType([]) + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10,), "float32") + + ft1 = relay.ty.FuncType([t1, t2], t3) + ft2 = relay.ty.FuncType([tensor1, tensor2], unit) + + unified = solver.Unify(ft1, ft2) + assert unified == ft2 + + +def test_recursive_unify(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + + tensor1 = relay.ty.TensorType((10, 10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 20), "float32") + + tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2]) + + ft1 = relay.ty.FuncType([tup1, tensor2], tensor2) + ft2 = relay.ty.FuncType([tup2, tensor2], tensor2) + + unified = solver.Unify(ft1, ft2) + assert unified == ft2 + + +def test_recursive_backward_solving(): + solver = make_solver() + + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 1, 1), "float32") + tensor3 = relay.ty.TensorType((10,), "float32") + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t3]) + solver.gen_type("Identity", [tup1], out=tup2) + + assert solver.Solve() + assert solver.Resolve(tup2) == tup1 + if __name__ == "__main__": test_bcast() test_backward_solving() + test_unify_tuple() + test_unify_functype() + test_recursive_unify() + test_recursive_backward_solving() From 2bad1af69b03d39c639ea7f45fc78f6cc391f1c2 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 29 Nov 2018 17:14:08 -0800 Subject: [PATCH 06/55] Add a negative test case --- tests/python/relay/test_type_solver.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 5f520687b5e9..37d242427ba9 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -1,5 +1,6 @@ import tvm from tvm import relay +from nose.tools import raises def make_rel(name, args, num_inputs=None, attrs=None): @@ -115,6 +116,21 @@ def test_recursive_backward_solving(): assert solver.Resolve(tup2) == tup1 +@raises(tvm._ffi.base.TVMError) +def test_incompatible_tuple_unification(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + + tensor1 = relay.ty.TensorType((1, 2, 3), "float32") + tensor2 = relay.ty.TensorType((2, 3), "float32") + tensor3 = relay.ty.TensorType((3,), "float32") + + tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t1]), t2]) + tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor3]) + solver.Unify(tup1, tup2) + + if __name__ == "__main__": test_bcast() test_backward_solving() @@ -122,3 +138,4 @@ def test_recursive_backward_solving(): test_unify_functype() test_recursive_unify() test_recursive_backward_solving() + test_incompatible_tuple_unification() From 0175c5353cdadf833d04fdc3d201b7b480a3e764 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 29 Nov 2018 17:41:55 -0800 Subject: [PATCH 07/55] Check for recursive equalities when unifying --- src/relay/pass/type_solver.cc | 35 ++++++++++++++++++++++++++ src/relay/pass/type_solver.h | 1 + tests/python/relay/test_type_solver.py | 26 +++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 37427152c3c4..3d559ff4e26e 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -39,6 +39,28 @@ class TypeSolver::Reporter : public TypeReporterNode { TypeSolver* solver_; }; +class TypeSolver::RecurrenceChecker : public TypeVisitor { + public: + explicit RecurrenceChecker(TypeSolver* solver, TypeNode* var) + : solver_(solver), var_(var), found_(false) {} + + bool Check(const Type& t) { + VisitType(t); + return found_; + } + + void VisitType_(const IncompleteTypeNode* op) override { + IncompleteType t = GetRef(op); + TypeNode* node = solver_->GetTypeNode(t); + found_ = found_ || (var_->FindRoot() == node->FindRoot()); + } + + private: + TypeSolver *solver_; + TypeNode *var_; + bool found_; +}; + class TypeSolver::Unifier : public TypeFunctor { public: explicit Unifier(TypeSolver* solver) : solver_(solver) {} @@ -54,9 +76,15 @@ class TypeSolver::Unifier : public TypeFunctor { return lhs->resolved_type; } if (lhs->resolved_type.as()) { + CHECK(!CheckRecurrence(lhs, rhs->resolved_type)) + << "Incomplete type " << lhs << " occurs in " + << rhs->resolved_type << ", cannot unify"; solver_->MergeFromTo(lhs, rhs); return rhs->resolved_type; } else if (rhs->resolved_type.as()) { + CHECK(!CheckRecurrence(rhs, lhs->resolved_type)) + << "Incomplete type " << rhs << " occurs in " + << lhs->resolved_type << ", cannot unify"; solver_->MergeFromTo(rhs, lhs); return lhs->resolved_type; } else { @@ -94,6 +122,13 @@ class TypeSolver::Unifier : public TypeFunctor { } } + // Checks whether lhs (taken to be a type var) appears in t, meaning + // there is a recursive equality constraint, which should be rejected. + bool CheckRecurrence(TypeNode *lhs, const Type &t) { + RecurrenceChecker rc(solver_, lhs); + return rc.Check(t); + } + // default: unify only if alpha-equal Type VisitTypeDefault_(const Node* op, const Type& tn) override { NodeRef nr = GetRef(op); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 327910103cdc..b62d1f076f10 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -65,6 +65,7 @@ class TypeSolver { Type Unify(const Type& lhs, const Type& rhs); private: + class RecurrenceChecker; class Unifier; class Reporter; struct TypeNode; diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 37d242427ba9..237804ce4483 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -97,6 +97,23 @@ def test_recursive_unify(): assert unified == ft2 +def test_unify_vars_under_tuples(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([t1, t1]) + unified = solver.Unify(tup1, tup1) + assert unified == tup1 + + t2 = relay.ty.IncompleteType() + tup2 = relay.ty.TupleType([t2, t2]) + + tup3 = relay.ty.TupleType([t1, t2]) + tup4 = relay.ty.TupleType([t2, t1]) + unified = solver.Unify(tup3, tup4) + assert (unified == tup1 or unified == tup2) + + def test_recursive_backward_solving(): solver = make_solver() @@ -131,11 +148,20 @@ def test_incompatible_tuple_unification(): solver.Unify(tup1, tup2) +@raises(tvm._ffi.base.TVMError) +def test_bad_recursive_unification(): + solver = make_solver() + t1 = relay.ty.IncompleteType() + solver.Unify(t1, relay.ty.TupleType([t1, t1])) + + if __name__ == "__main__": test_bcast() test_backward_solving() test_unify_tuple() test_unify_functype() test_recursive_unify() + test_unify_vars_under_tuples() test_recursive_backward_solving() test_incompatible_tuple_unification() + test_bad_recursive_unification() From 6f5db7c31c70f8a4a5d8e643315ab7e700958b36 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 29 Nov 2018 20:07:15 -0800 Subject: [PATCH 08/55] Minor tweaks to error messages --- src/relay/pass/type_solver.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 3d559ff4e26e..cbe5d93705cf 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -77,19 +77,18 @@ class TypeSolver::Unifier : public TypeFunctor { } if (lhs->resolved_type.as()) { CHECK(!CheckRecurrence(lhs, rhs->resolved_type)) - << "Incomplete type " << lhs << " occurs in " + << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type << ", cannot unify"; solver_->MergeFromTo(lhs, rhs); return rhs->resolved_type; } else if (rhs->resolved_type.as()) { CHECK(!CheckRecurrence(rhs, lhs->resolved_type)) - << "Incomplete type " << rhs << " occurs in " + << "Incomplete type " << rhs->resolved_type << " occurs in " << lhs->resolved_type << ", cannot unify"; solver_->MergeFromTo(rhs, lhs); return lhs->resolved_type; } else { - Type resolved = - this->VisitType(lhs->resolved_type, rhs->resolved_type); + Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); CHECK(resolved.defined()) << "Unable to unify parent types: " << lhs->resolved_type << " and " << rhs->resolved_type; From 16de70d6fcff2460d8e224aed8820fbbb3c6243b Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 29 Nov 2018 20:09:39 -0800 Subject: [PATCH 09/55] Improve recursive unification test --- tests/python/relay/test_type_solver.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 237804ce4483..01c68e5f37b5 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -83,15 +83,17 @@ def test_recursive_unify(): solver = make_solver() t1 = relay.ty.IncompleteType() t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() tensor1 = relay.ty.TensorType((10, 10, 20), "float32") tensor2 = relay.ty.TensorType((10, 20), "float32") + tensor3 = relay.ty.TensorType((10,), "float32") tup1 = relay.ty.TupleType([relay.ty.TupleType([t1, t2]), t2]) tup2 = relay.ty.TupleType([relay.ty.TupleType([tensor1, tensor2]), tensor2]) - ft1 = relay.ty.FuncType([tup1, tensor2], tensor2) - ft2 = relay.ty.FuncType([tup2, tensor2], tensor2) + ft1 = relay.ty.FuncType([tup1, t3], t3) + ft2 = relay.ty.FuncType([tup2, tensor3], tensor3) unified = solver.Unify(ft1, ft2) assert unified == ft2 From 8829afef1141999844fb14915aee08cd41bfb947 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 29 Nov 2018 20:13:38 -0800 Subject: [PATCH 10/55] Do not copy relation list nodes if already resolved --- src/relay/pass/type_solver.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index cbe5d93705cf..c0bac11b1781 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -106,16 +106,17 @@ class TypeSolver::Unifier : public TypeFunctor { TypeNode* parent_node = solver_->GetTypeNode(parent); TypeNode* child_node = solver_->GetTypeNode(child); - // allocate copies to avoid making a circular linked list + // allocate copies to avoid introducing circular link for (auto* rlink = parent_node->rel_list.head; rlink != nullptr;) { auto* next = rlink->next; auto* value = rlink->value; if (!value->resolved) { solver_->AddToQueue(value); + + auto* copy = solver_->arena_.make >(); + copy->value = value; + child_node->rel_list.Push(copy); } - auto* copy = solver_->arena_.make >(); - copy->value = value; - child_node->rel_list.Push(copy); rlink = next; } From 5027d76b627844132765ac7574bf2ffdd64528c6 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 30 Nov 2018 01:04:36 -0800 Subject: [PATCH 11/55] Add visitor for type resolution, have more complicated unification test case --- src/relay/pass/type_solver.cc | 30 ++++++++++++++++++----- src/relay/pass/type_solver.h | 1 + tests/python/relay/test_type_solver.py | 33 ++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index c0bac11b1781..29a57b5ec355 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -56,8 +56,8 @@ class TypeSolver::RecurrenceChecker : public TypeVisitor { } private: - TypeSolver *solver_; - TypeNode *var_; + TypeSolver* solver_; + TypeNode* var_; bool found_; }; @@ -112,7 +112,6 @@ class TypeSolver::Unifier : public TypeFunctor { auto* value = rlink->value; if (!value->resolved) { solver_->AddToQueue(value); - auto* copy = solver_->arena_.make >(); copy->value = value; child_node->rel_list.Push(copy); @@ -208,6 +207,23 @@ class TypeSolver::Unifier : public TypeFunctor { TypeSolver* solver_; }; +class TypeSolver::Resolver : public TypeMutator { + public: + explicit Resolver(TypeSolver* solver) : solver_(solver) {} + + Type Resolve(const Type& t) { + return VisitType(t); + } + + Type VisitType_(const IncompleteTypeNode* op) override { + auto* node = solver_->GetTypeNode(GetRef(op)); + return node->resolved_type; + } + + private: + TypeSolver* solver_; +}; + // constructor TypeSolver::TypeSolver() : reporter_(make_node(this)) { @@ -260,11 +276,13 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { // Resolve a type in the solver context. Type TypeSolver::Resolve(const Type& type) { auto it = tmap_.find(type); + Type t = type; if (it != tmap_.end()) { - return it->second->FindRoot()->resolved_type; - } else { - return type; + t = it->second->FindRoot()->resolved_type; } + + Resolver resolver(this); + return resolver.Resolve(t); } bool TypeSolver::Solve() { diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index b62d1f076f10..a16a3b26bb29 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -67,6 +67,7 @@ class TypeSolver { private: class RecurrenceChecker; class Unifier; + class Resolver; class Reporter; struct TypeNode; struct RelationNode; diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 01c68e5f37b5..1e545f8dd6aa 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -135,6 +135,38 @@ def test_recursive_backward_solving(): assert solver.Resolve(tup2) == tup1 +def test_backward_solving_after_child_update(): + solver = make_solver() + + tensor1 = relay.ty.TensorType((10, 20), "float32") + tensor2 = relay.ty.TensorType((10, 1, 1), "float32") + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + t3 = relay.ty.IncompleteType() + + tup1 = relay.ty.TupleType([t1, t2]) + tup2 = relay.ty.TupleType([t1, t3]) + + tup_concrete = relay.ty.TupleType([tensor1, tensor2]) + + t4 = solver.gen_type("Identity", [tup1]) + t5 = solver.gen_type("Identity", [tup2]) + + solver.gen_type("Identity", [t4], out=t5) + assert solver.Solve() + assert solver.Resolve(t3) == t3 or solver.Resolve(t3) == t2 + assert solver.Resolve(t4) == tup1 or solver.Resolve(t4) == tup2 + assert solver.Resolve(t5) == tup1 or solver.Resolve(t5) == tup2 + + # updating the variables *inside* tup1 and tup2 should update t4 and t5 + solver.gen_type("Identity", [t1], out=tensor1) + solver.gen_type("Identity", [t2], out=tensor2) + assert solver.Solve() + assert solver.Resolve(t4) == tup_concrete + assert solver.Resolve(t5) == tup_concrete + + @raises(tvm._ffi.base.TVMError) def test_incompatible_tuple_unification(): solver = make_solver() @@ -165,5 +197,6 @@ def test_bad_recursive_unification(): test_recursive_unify() test_unify_vars_under_tuples() test_recursive_backward_solving() + test_backward_solving_after_child_update() test_incompatible_tuple_unification() test_bad_recursive_unification() From ce59f5be8221b49e9a018641fec442118e2eef02 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 30 Nov 2018 01:11:08 -0800 Subject: [PATCH 12/55] Avoid catastrophic failure in Resolve() by only recursing in one branch (not sure why this fails) --- src/relay/pass/type_solver.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 29a57b5ec355..7e641db371d6 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -275,14 +275,13 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { // Resolve a type in the solver context. Type TypeSolver::Resolve(const Type& type) { + Resolver resolver(this); auto it = tmap_.find(type); - Type t = type; if (it != tmap_.end()) { - t = it->second->FindRoot()->resolved_type; + return resolver.Resolve(it->second->FindRoot()->resolved_type); + } else { + return type; } - - Resolver resolver(this); - return resolver.Resolve(t); } bool TypeSolver::Solve() { From 4d1bde6ab26017bba84a307807f23def452eb976 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 30 Nov 2018 01:43:46 -0800 Subject: [PATCH 13/55] Add a null check before resolution --- src/relay/pass/type_solver.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 7e641db371d6..0c7acba1403e 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -280,7 +280,10 @@ Type TypeSolver::Resolve(const Type& type) { if (it != tmap_.end()) { return resolver.Resolve(it->second->FindRoot()->resolved_type); } else { - return type; + if (!type.defined()) { + return type; + } + return resolver.Resolve(type); } } From c44bb7328ffe0ce9cb51101e6b8aca72456f34aa Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 30 Nov 2018 01:47:36 -0800 Subject: [PATCH 14/55] Move null check into resolution visitor --- src/relay/pass/type_solver.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 0c7acba1403e..373642977dff 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -212,6 +212,9 @@ class TypeSolver::Resolver : public TypeMutator { explicit Resolver(TypeSolver* solver) : solver_(solver) {} Type Resolve(const Type& t) { + if (!t.defined()) { + return t; + } return VisitType(t); } @@ -277,14 +280,8 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { Type TypeSolver::Resolve(const Type& type) { Resolver resolver(this); auto it = tmap_.find(type); - if (it != tmap_.end()) { - return resolver.Resolve(it->second->FindRoot()->resolved_type); - } else { - if (!type.defined()) { - return type; - } - return resolver.Resolve(type); - } + Type t = (it != tmap_.end()) ? it->second->FindRoot()->resolved_type : type; + return resolver.Resolve(t); } bool TypeSolver::Solve() { From 2b18e21ec1f188d4795ff5340701e8f301a64e98 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sat, 1 Dec 2018 18:24:58 -0800 Subject: [PATCH 15/55] Rename RecurrenceChecker to OccursChecker --- src/relay/pass/type_solver.cc | 16 +++++++++------- src/relay/pass/type_solver.h | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 373642977dff..2ebec141593a 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -39,9 +39,9 @@ class TypeSolver::Reporter : public TypeReporterNode { TypeSolver* solver_; }; -class TypeSolver::RecurrenceChecker : public TypeVisitor { +class TypeSolver::OccursChecker : public TypeVisitor { public: - explicit RecurrenceChecker(TypeSolver* solver, TypeNode* var) + explicit OccursChecker(TypeSolver* solver, TypeNode* var) : solver_(solver), var_(var), found_(false) {} bool Check(const Type& t) { @@ -76,13 +76,13 @@ class TypeSolver::Unifier : public TypeFunctor { return lhs->resolved_type; } if (lhs->resolved_type.as()) { - CHECK(!CheckRecurrence(lhs, rhs->resolved_type)) + CHECK(!CheckOccurs(lhs, rhs->resolved_type)) << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type << ", cannot unify"; solver_->MergeFromTo(lhs, rhs); return rhs->resolved_type; } else if (rhs->resolved_type.as()) { - CHECK(!CheckRecurrence(rhs, lhs->resolved_type)) + CHECK(!CheckOccurs(rhs, lhs->resolved_type)) << "Incomplete type " << rhs->resolved_type << " occurs in " << lhs->resolved_type << ", cannot unify"; solver_->MergeFromTo(rhs, lhs); @@ -121,10 +121,12 @@ class TypeSolver::Unifier : public TypeFunctor { } } - // Checks whether lhs (taken to be a type var) appears in t, meaning + // Checks whether lhs (taken to be a type var) occurs in t, meaning // there is a recursive equality constraint, which should be rejected. - bool CheckRecurrence(TypeNode *lhs, const Type &t) { - RecurrenceChecker rc(solver_, lhs); + // N.b.: A tautology like ?a = ?a is okay and should be checked for + // *before* calling this method + bool CheckOccurs(TypeNode *lhs, const Type &t) { + OccursChecker rc(solver_, lhs); return rc.Check(t); } diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index a16a3b26bb29..fc0a057df1dc 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -65,7 +65,7 @@ class TypeSolver { Type Unify(const Type& lhs, const Type& rhs); private: - class RecurrenceChecker; + class OccursChecker; class Unifier; class Resolver; class Reporter; From beb2bb5e109f5326fc2a9238ac78dd35efde3b0d Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 2 Dec 2018 21:00:39 -0800 Subject: [PATCH 16/55] Recursively propagate type relations to child types when a constraint is added, not during unification. Resolve before running a relation --- src/relay/pass/type_solver.cc | 92 +++++++++++++++++++++++------------ src/relay/pass/type_solver.h | 1 + 2 files changed, 61 insertions(+), 32 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 2ebec141593a..60bd4a70bd51 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -99,28 +99,6 @@ class TypeSolver::Unifier : public TypeFunctor { } } - // child type needs to be listed in parent's relations, even though - // the child is not an argument to the relations (still have to - // update the relations if the child changes) - void RegisterChildType(const Type& parent, const Type& child) { - TypeNode* parent_node = solver_->GetTypeNode(parent); - TypeNode* child_node = solver_->GetTypeNode(child); - - // allocate copies to avoid introducing circular link - for (auto* rlink = parent_node->rel_list.head; rlink != nullptr;) { - auto* next = rlink->next; - auto* value = rlink->value; - if (!value->resolved) { - solver_->AddToQueue(value); - auto* copy = solver_->arena_.make >(); - copy->value = value; - child_node->rel_list.Push(copy); - } - - rlink = next; - } - } - // Checks whether lhs (taken to be a type var) occurs in t, meaning // there is a recursive equality constraint, which should be rejected. // N.b.: A tautology like ?a = ?a is okay and should be checked for @@ -152,8 +130,6 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector new_fields; for (size_t i = 0; i < tt1->fields.size(); i++) { Type field = Unify(tt1->fields[i], tt2->fields[i]); - RegisterChildType(tt1, field); - RegisterChildType(tt2, field); new_fields.push_back(field); } return TupleTypeNode::make(new_fields); @@ -172,14 +148,10 @@ class TypeSolver::Unifier : public TypeFunctor { FuncType ft2 = GetRef(ftn); Type ret_type = Unify(ft1->ret_type, ft2->ret_type); - RegisterChildType(ft1, ret_type); - RegisterChildType(ft2, ret_type); std::vector arg_types; for (size_t i = 0; i < ft1->arg_types.size(); i++) { Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); - RegisterChildType(ft1, arg_type); - RegisterChildType(ft2, arg_type); arg_types.push_back(arg_type); } @@ -229,6 +201,63 @@ class TypeSolver::Resolver : public TypeMutator { TypeSolver* solver_; }; +// It ends up being more compact to simply have TypeFunctor { + public: + explicit Propagator(TypeSolver* solver, RelationNode* rel) : solver_(solver), rel_(rel) {} + + // adds the relation node to t and all child types of t + void Propagate(const Type& t) { + VisitType(t); + } + + void AddRelToList(const Type& t) { + TypeNode* tnode = solver_->GetTypeNode(t); + LinkNode* rlink = solver_->arena_.make >(); + rlink->value = rel_; + tnode->rel_list.Push(rlink); + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + AddRelToList(t); + } + + void VisitType_(const TupleTypeNode* op) override { + TupleType tt = GetRef(op); + AddRelToList(tt); + + for (const Type& t : tt->fields) { + Propagate(t); + } + } + + void VisitType_(const FuncTypeNode* op) override { + FuncType ft = GetRef(op); + AddRelToList(ft); + + Propagate(ft->ret_type); + for (auto arg_type : ft->arg_types) { + Propagate(arg_type); + } + + for (auto type_param : ft->type_params) { + Propagate(type_param); + } + + for (auto type_cs : ft->type_constraints) { + Propagate(type_cs); + } + } + + private: + TypeSolver* solver_; + RelationNode* rel_; +}; + // constructor TypeSolver::TypeSolver() : reporter_(make_node(this)) { @@ -266,9 +295,8 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - LinkNode* rlink = arena_.make >(); - rlink->value = rnode; - tnode->rel_list.Push(rlink); + Propagator prop(this, rnode); + prop.Propagate(tnode->resolved_type); } // add the relation to the working queue. this->AddToQueue(rnode); @@ -296,7 +324,7 @@ bool TypeSolver::Solve() { // update the relation with given evidence. Array args; for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { - args.push_back(tlink->value->FindRoot()->resolved_type); + args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); CHECK_LE(args.size(), rel->args.size()); } // call the function diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index fc0a057df1dc..1b35718cd473 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -68,6 +68,7 @@ class TypeSolver { class OccursChecker; class Unifier; class Resolver; + class Propagator; class Reporter; struct TypeNode; struct RelationNode; From 36ae3f89728a136462cbe351eacc46050e97e113 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 6 Dec 2018 20:31:50 -0800 Subject: [PATCH 17/55] Make use of new unification in type inference, add generalization, fix broken tests --- src/relay/pass/type_infer.cc | 123 +++++++++++++++++--------- tests/python/relay/test_type_infer.py | 113 +++++++++++++++++------ 2 files changed, 164 insertions(+), 72 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ee1b5ab10148..22f2122c3018 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -56,31 +56,11 @@ bool TupleGetItemRel(const Array& types, return true; } -bool MakeTupleRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(static_cast(num_inputs + 1), types.size()); - for (int i = 0; i < num_inputs; ++i) { - if (types[i].as()) return false; - } - Array fields; - for (int i = 0; i < num_inputs; ++i) { - fields.push_back(types[i]); - } - reporter->Assign(types[num_inputs], TupleTypeNode::make(fields)); - return true; -} - TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") .set_body_typed&, int, const Attrs&, const TypeReporter&)>( TupleGetItemRel); -TVM_REGISTER_API("tvm.relay.type_relation.MakeTuple") -.set_body_typed&, int, const Attrs&, const TypeReporter&)>( - MakeTupleRel); - struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) : checked_type(checked_type), type_args(type_args) {} @@ -92,6 +72,41 @@ struct ResolvedTypeInfo { Array type_args = Array(NodePtr(nullptr)); }; +// Converts incomplete types remaining in function signature to type vars +class Generalizer : public TypeMutator { + public: + Generalizer() : subst_map_({}), varno_(0) {} + + // turns each distinct incomplete type into a type var and returns + // the transformed type with an array of all type vars present + Type Generalize(const Type &t, Array* vars) { + vars_ = vars; + return VisitType(t); + } + + Type VisitType_(const IncompleteTypeNode *op) override { + IncompleteType t = GetRef(op); + auto it = subst_map_.find(t); + if (it != subst_map_.end()) { + return (*it).second; + } + + // generate a new type var, add to list + std::stringstream ss; + ss << "_var_" << varno_; + varno_++; + TypeVar new_var = TypeVarNode::make(ss.str(), TypeVarNode::Kind::kType); + vars_->push_back(new_var); + subst_map_.Set(t, new_var); + return new_var; + } + + private: + tvm::Map subst_map_; + Array* vars_; + int varno_; +}; + // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -175,19 +190,11 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const TupleNode* op) final { - if (!make_tuple_rel_.defined()) { - make_tuple_rel_ = TypeRelationFn( - EnvFunc::Get("tvm.relay.type_relation.MakeTuple").node_); - } Array types; for (Expr field : op->fields) { types.push_back(GetType(field)); } - Type rtype = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - types.push_back(rtype); - solver_.AddConstraint(TypeRelationNode::make( - make_tuple_rel_, types, op->fields.size(), Attrs())); - return rtype; + return TupleTypeNode::make(types); } Type VisitExpr_(const TupleGetItemNode* op) final { @@ -253,14 +260,18 @@ class TypeInferencer : private ExprFunctor { } // instantiate the function type with fresh - FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args) { + FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args, const Span& span) { tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. - for (auto ty_param : fn_ty->type_params) { + for (size_t i = 0; i < fn_ty->type_params.size(); i++) { + auto ty_param = fn_ty->type_params[i]; IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); subst_map.Set(ty_param, fresh); + if (i < ty_args->size()) { + this->Unify(fresh, (*ty_args)[i], span); + } ty_args->push_back(fresh); } @@ -296,13 +307,23 @@ class TypeInferencer : private ExprFunctor { Type GeneralCall(const CallNode* call, Array arg_types) { Type ftype = GetType(call->op); auto* fn_ty_node = ftype.as(); + auto* inc_ty_node = ftype.as(); + + CHECK(fn_ty_node != nullptr || inc_ty_node != nullptr) + << "only expressions with function types can be called, found " + << ftype << " at " << call->span; + + // incomplete type => it must be a function taking the arg types + // with an unknown return type + if (inc_ty_node != nullptr) { + Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); + Type unified = this->Unify(ftype, func_type, call->span); + fn_ty_node = unified.as(); + } - CHECK(fn_ty_node != nullptr) - << "only expressions with function types can be called, found " - << ftype << " at " << call->span; - - Array type_args; - FuncType fn_ty = Instantiate(fn_ty_node, &type_args); + Array type_args = call->type_args; + FuncType fn_ty = Instantiate(fn_ty_node, &type_args, call->span); AddTypeArgs(GetRef(call), type_args); @@ -357,22 +378,36 @@ class TypeInferencer : private ExprFunctor { GetType(param); } Type rtype = GetType(f->body); + if (f->ret_type.defined()) { + rtype = this->Unify(f->ret_type, rtype, f->span); + } + // Run solver using the currently known information solver_.Solve(); // Trying to resolve Array arg_types; + Array type_params = f->type_params; + Generalizer gen; + for (size_t i = 0; i < f->params.size(); ++i) { Type atype = solver_.Resolve(GetType(f->params[i])); - CHECK(atype.as() == nullptr) - << "Cannot resolve type of " << i - << "-th parameter of function at" << f->span; + // CHECK(atype.as() == nullptr) + // << "Cannot resolve type of " << i + // << "-th parameter of function at" << f->span; + Type gen_atype = gen.Generalize(atype, &type_params); + atype = this->Unify(atype, gen_atype, f->span); arg_types.push_back(atype); } + rtype = solver_.Resolve(rtype); - CHECK(rtype.as() == nullptr) - << "Cannot resolve return type of function at" << f->span; + Type gen_rtype = gen.Generalize(rtype, &type_params); + this->Unify(rtype, gen_rtype, f->span); + rtype = gen_rtype; + + // CHECK(rtype.as() == nullptr) + // << "Cannot resolve return type of function at" << f->span; // do not support constraint lifting for now. - return FuncTypeNode::make(arg_types, rtype, f->type_params, {}); + return FuncTypeNode::make(arg_types, rtype, type_params, {}); } }; @@ -380,7 +415,7 @@ class TypeInferencer::Resolver : public ExprMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { + : tmap_(tmap), solver_(solver) { } Expr VisitExpr_(const VarNode* op) final { diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 06cb19639dcf..ff5fcfd0aec0 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -23,7 +23,7 @@ def test_monomorphic_let(): x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) xchecked = relay.ir_pass.infer_type(sb.get()) - assert xchecked.checked_type == relay.scalar_type("float64") + assert xchecked.checked_type == relay.scalar_type("float64" ) def test_single_op(): @@ -41,14 +41,15 @@ def test_add_broadcast_op(): return x + y; } """ - pass - # x = relay.var('x', shape=(10, 4)) - # y = relay.var('y', shape=(5, 10, 1)) - # z = x + y - # func = relay.Function([x, y], z) - # ttype = relay.TensorType((5, 5, 5), 'float32') - # expected_ty = relay.FuncType([ttype, ttype], ttype) - # assert_has_type(func.to_func(), expected_ty) + x = relay.var('x', shape=(10, 4)) + y = relay.var('y', shape=(5, 10, 1)) + z = x + y + func = relay.Function([x, y], z) + t1 = relay.TensorType((10, 4), 'float32') + t2 = relay.TensorType((5, 10, 1), 'float32') + t3 = relay.TensorType((5, 10, 4), 'float32') + expected_ty = relay.FuncType([t1, t2], t3) + assert_has_type(func, expected_ty) def test_dual_op(): @@ -110,24 +111,52 @@ def f(n: i32, data: f32) -> f32 { assert "%3 = @f(%1, %2)" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32) -# This currently fails and should pass under the type system. -# -# This test is to illustrate problem with our weak form of -# unification. -# - def test_incomplete_call(): - sb = ScopeBuilder() - x = relay.var('x', dtype='int32') + tt = relay.scalar_type('int32') + x = relay.var('x', tt) + f = relay.var('f') + func = relay.Function([x, f], relay.Call(f, [x]), tt) + + ft = relay.ir_pass.infer_type(func) + f_type = relay.FuncType([tt], tt) + assert ft.checked_type == relay.FuncType([tt, f_type], tt) + + +def test_call_with_type_args(): + a = relay.TypeVar('a') + b = relay.TypeVar('b') + + x = relay.Var('x', a) + f = relay.Var('f', relay.FuncType([a], b)) + func = relay.Function([x, f], relay.Call(f, [x]), b, [a, b]) + + unit_type = relay.TupleType([]) + v = relay.Var('v', unit_type) + concrete_func = relay.Function( + [], + relay.Call( + func, + [relay.Tuple([]), + relay.Function([v], relay.Tuple([]))], + type_args=[unit_type, unit_type]), + unit_type) + + ft = relay.ir_pass.infer_type(concrete_func) + assert ft.checked_type == relay.FuncType([], unit_type) + + +def test_generalized_call(): + x = relay.var('x') f = relay.var('f') func = relay.Function([x, f], relay.Call(f, [x])) - try: - relay.ir_pass.infer_type(func) - assert False - except tvm.TVMError as e: - assert True + a = relay.TypeVar('a') + b = relay.TypeVar('b') + + ft = relay.ir_pass.infer_type(func) + assert ft.checked_type == relay.FuncType([a, relay.FuncType([a], b)], b, [a, b]) + def test_tuple(): tp = relay.TensorType((10,)) @@ -136,6 +165,24 @@ def test_tuple(): assert (relay.ir_pass.infer_type(res).checked_type == relay.TupleType([tp, tp])) + +def test_generalized_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.var('z') + + func = relay.Function([x, y, z], relay.Tuple([x, y, z])) + + a = relay.TypeVar('a') + b = relay.TypeVar('b') + c = relay.TypeVar('c') + ft = relay.ir_pass.infer_type(func) + assert ft.checked_type == relay.FuncType( + [a, b, c], + relay.TupleType([a, b, c]), + [a, b, c]) + + def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) @@ -179,20 +226,26 @@ def f(x) { assert relay.ir_pass.infer_type(fx).checked_type == a -def test_global_var_cow_issue(): +def test_global_var_recursion(): mod = relay.Module({}) gv = relay.GlobalVar("foo") x = relay.var('x', shape=[]) - func = relay.Function([x], relay.Call(gv, [x]), - relay.TensorType([], 'float32')) + tt = relay.scalar_type('float32') + + func = relay.Function([x], relay.Call(gv, [x]), tt) mod[gv] = func + ft = relay.ir_pass.infer_type(gv, mod) + assert mod[ft].checked_type == relay.FuncType([tt], tt) + def test_equal(): i = relay.var('i', shape=[], dtype='int32') eq = op.equal(i, relay.const(0, dtype='int32')) - # This should fail .... - func = relay.Function([i], eq, ret_type=relay.TensorType([], 'int32')) + func = relay.Function([i], eq) + ft = relay.ir_pass.infer_type(func) + + assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) if __name__ == "__main__": @@ -204,8 +257,12 @@ def test_equal(): test_decl() test_recursion() test_tuple() + test_generalized_tuple() test_incomplete_call() + test_generalized_call() + test_call_with_type_args() test_free_expr() test_type_args() test_self_reference() - test_global_var_cow_issue() + test_global_var_recursion() + test_equal() From 5f355761ee033e853f22e9f614f5d399f4477b9b Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 7 Dec 2018 14:53:19 -0800 Subject: [PATCH 18/55] Intermediate progress on fixing the unifier (this breaks stuff) --- include/tvm/relay/type.h | 34 ++++++++++ src/relay/pass/type_infer.cc | 89 ++++++++++++++++++++------- src/relay/pass/type_solver.cc | 69 ++++++++++++++------- src/relay/pass/type_solver.h | 3 + tests/python/relay/test_type_infer.py | 3 +- 5 files changed, 154 insertions(+), 44 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 69a8a4fb0bd7..ad62d1d78571 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -317,6 +317,40 @@ class TypeReporter : public NodeRef { using ContainerType = TypeReporterNode; }; +class TypeUnifier; +/*! + * \brief Data structure that handles type unification. + */ +class TypeUnifierNode : public Node { + public: + /*! + * \brief Returns a unified type based on the two arguments. + */ + TVM_DLL virtual Type Unify(const Type& dst, const Type& src) = 0; + + // unifier is not serializable. + void VisitAttrs(tvm::AttrVisitor* v) final {} + + static constexpr const char* _type_key = "relay.TypeUnifier"; + TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); +}; + +/*! + * \brief Container class of TypeUnifier. + * \sa TypeUnifierNode + */ +class TypeUnifier : public NodeRef { + public: + TypeUnifier() {} + explicit TypeUnifier(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { + } + TypeUnifierNode* operator->() const { + return static_cast(node_.get()); + } + using ContainerType = TypeUnifierNode; +}; + + /*! * \brief User defined type constraint function. * diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 22f2122c3018..640c0c8c0414 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -374,40 +374,85 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const FunctionNode* f) final { + // first get blank f type + // unify with candidate + // generalize + Array incomplete_arg_types; for (auto param : f->params) { - GetType(param); + incomplete_arg_types.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + } + FuncType incompleteFuncType = + FuncTypeNode::make(incomplete_arg_types, + IncompleteTypeNode::make(TypeVarNode::Kind::kType), + {}, {}); + + Array candidate_arg_types; + for (auto param : f->params) { + candidate_arg_types.push_back(GetType(param)); } Type rtype = GetType(f->body); if (f->ret_type.defined()) { rtype = this->Unify(f->ret_type, rtype, f->span); } + FuncType candidateFuncType = FuncTypeNode::make(candidate_arg_types, + rtype, + f->type_params, {}); - // Run solver using the currently known information solver_.Solve(); - // Trying to resolve - Array arg_types; - Array type_params = f->type_params; - Generalizer gen; + auto unified = + GetRef(Unify(incompleteFuncType, + candidateFuncType, f->span) + .as()); - for (size_t i = 0; i < f->params.size(); ++i) { - Type atype = solver_.Resolve(GetType(f->params[i])); - // CHECK(atype.as() == nullptr) - // << "Cannot resolve type of " << i - // << "-th parameter of function at" << f->span; - Type gen_atype = gen.Generalize(atype, &type_params); - atype = this->Unify(atype, gen_atype, f->span); + // generalize remaining incomplete types + Generalizer gen; + Array arg_types; + Array type_params; + for (auto param : unified->type_params) { + Type gen_param = gen.Generalize(param, &type_params); + Type atype = Unify(param, gen_param, f->span); arg_types.push_back(atype); } - rtype = solver_.Resolve(rtype); - Type gen_rtype = gen.Generalize(rtype, &type_params); - this->Unify(rtype, gen_rtype, f->span); - rtype = gen_rtype; - - // CHECK(rtype.as() == nullptr) - // << "Cannot resolve return type of function at" << f->span; - // do not support constraint lifting for now. - return FuncTypeNode::make(arg_types, rtype, type_params, {}); + Type gen_ret = gen.Generalize(unified->ret_type, &type_params); + Type final_ret = Unify(gen_ret, unified->ret_type, f->span); + + return FuncTypeNode::make(arg_types, final_ret, type_params, {}); + + // for (auto param : f->params) { + // GetType(param); + // } + // Type rtype = GetType(f->body); + // if (f->ret_type.defined()) { + // rtype = this->Unify(f->ret_type, rtype, f->span); + // } + + // // Run solver using the currently known information + // solver_.Solve(); + // // Trying to resolve + // Array arg_types; + // Array type_params = f->type_params; + // Generalizer gen; + + // for (size_t i = 0; i < f->params.size(); ++i) { + // Type atype = solver_.Resolve(GetType(f->params[i])); + // // CHECK(atype.as() == nullptr) + // // << "Cannot resolve type of " << i + // // << "-th parameter of function at" << f->span; + // Type gen_atype = gen.Generalize(atype, &type_params); + // atype = this->Unify(atype, gen_atype, f->span); + // arg_types.push_back(atype); + // } + + // rtype = solver_.Resolve(rtype); + // Type gen_rtype = gen.Generalize(rtype, &type_params); + // this->Unify(rtype, gen_rtype, f->span); + // rtype = gen_rtype; + + // // CHECK(rtype.as() == nullptr) + // // << "Cannot resolve return type of function at" << f->span; + // // do not support constraint lifting for now. + // return FuncTypeNode::make(arg_types, rtype, type_params, {}); } }; diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 60bd4a70bd51..be97fe3e66b3 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -61,15 +61,18 @@ class TypeSolver::OccursChecker : public TypeVisitor { bool found_; }; -class TypeSolver::Unifier : public TypeFunctor { +class TypeSolver::Unifier : + public TypeUnifierNode, public TypeFunctor { public: - explicit Unifier(TypeSolver* solver) : solver_(solver) {} + explicit Unifier(TypeSolver* solver) : solver_(solver), tv_map_({}) {} - Type Unify(const Type& src, const Type& dst) { + Type Unify(const Type& src, const Type& dst) final { // Known limitation // - handle shape pattern matching - TypeNode* lhs = solver_->GetTypeNode(dst); - TypeNode* rhs = solver_->GetTypeNode(src); + Type new_src = InstantiateTypeVar(src); + Type new_dst = InstantiateTypeVar(dst); + TypeNode* lhs = solver_->GetTypeNode(new_dst); + TypeNode* rhs = solver_->GetTypeNode(new_src); // do occur check so we don't create self-referencing structure if (lhs->FindRoot() == rhs->FindRoot()) { @@ -108,6 +111,40 @@ class TypeSolver::Unifier : public TypeFunctor { return rc.Check(t); } + // if t is a type var, replace with an incomplete type + Type InstantiateTypeVar(const Type& t) { + auto* tvn = t.as(); + if (tvn == nullptr) { + return t; + } + + TypeVar tv = GetRef(tvn); + auto it = tv_map_.find(tv); + if (tv_map_.find(tv) != tv_map_.end()) { + return (*it).second; + } + + IncompleteType hole = IncompleteTypeNode::make(tv->kind); + tv_map_.Set(tv, hole); + return hole; + } + + // instantiate away all type parameters in a function type + FuncType InstantiateFuncType(const FuncType& ft) { + Map subst_map; + for (auto type_param : ft->type_params) { + Type hole = InstantiateTypeVar(type_param); + subst_map.Set(type_param, hole); + } + + Type transformed = Bind(ft, subst_map); + auto* new_ft = transformed.as(); + // drop the type param list altogether + return FuncTypeNode::make(new_ft->arg_types, + new_ft->ret_type, {}, + new_ft->type_constraints); + } + // default: unify only if alpha-equal Type VisitTypeDefault_(const Node* op, const Type& tn) override { NodeRef nr = GetRef(op); @@ -139,13 +176,12 @@ class TypeSolver::Unifier : public TypeFunctor { const auto* ftn = tn.as(); if (!ftn || op->arg_types.size() != ftn->arg_types.size() - || op->type_params.size() != ftn->type_params.size() || op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } - FuncType ft1 = GetRef(op); - FuncType ft2 = GetRef(ftn); + FuncType ft1 = InstantiateFuncType(GetRef(op)); + FuncType ft2 = InstantiateFuncType(GetRef(ftn)); Type ret_type = Unify(ft1->ret_type, ft2->ret_type); @@ -155,15 +191,6 @@ class TypeSolver::Unifier : public TypeFunctor { arg_types.push_back(arg_type); } - std::vector type_params; - for (size_t i = 0; i < ft1->type_params.size(); i++) { - Type unified_var = Unify(ft1->type_params[i], ft2->type_params[i]); - const auto* tvn = unified_var.as(); - CHECK(tvn) << "Two type vars unified into a non type var? " - << ft1->type_params[i] << " and " << ft2->type_params[i]; - type_params.push_back(GetRef(tvn)); - } - std::vector type_constraints; for (size_t i = 0; i < ft1->type_constraints.size(); i++) { Type unified_constraint = Unify(ft1->type_constraints[i], @@ -174,11 +201,12 @@ class TypeSolver::Unifier : public TypeFunctor { type_constraints.push_back(GetRef(tcn)); } - return FuncTypeNode::make(arg_types, ret_type, type_params, type_constraints); + return FuncTypeNode::make(arg_types, ret_type, {}, type_constraints); } private: TypeSolver* solver_; + Map tv_map_; }; class TypeSolver::Resolver : public TypeMutator { @@ -260,7 +288,7 @@ class TypeSolver::Propagator : public TypeFunctor { // constructor TypeSolver::TypeSolver() - : reporter_(make_node(this)) { + : reporter_(make_node(this)), unifier_(make_node(this)) { } // destructor @@ -276,8 +304,7 @@ TypeSolver::~TypeSolver() { // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - Unifier unifier(this); - return unifier.Unify(dst, src); + return unifier_->Unify(dst, src); } // Add type constraint to the solver. diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 1b35718cd473..c6d5c20d737d 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -18,6 +18,7 @@ namespace relay { using common::LinkNode; using common::LinkedList; + /*! * \brief Interface of type solver used in type inference. * @@ -135,6 +136,8 @@ class TypeSolver { common::Arena arena_; /*! \brief Reporter that reports back to self */ TypeReporter reporter_; + /*! \brief Data structure for unifying types */ + TypeUnifier unifier_; /*! * \brief GetTypeNode that is corresponds to t. * if it do not exist, create a new one. diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index ff5fcfd0aec0..6d1cd134a6eb 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -216,10 +216,11 @@ def f(x) { } """ a = relay.TypeVar("a") + b = relay.TypeVar("b") x = relay.var("x", a) sb = relay.ScopeBuilder() - f = relay.Function([x], x) + f = relay.Function([x], x, b, [a, b]) fx = relay.Call(f, [x]) assert relay.ir_pass.infer_type(x).checked_type == a assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) From db247adaa872f1bfc610cf901b0d4d8eca4eaa99 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 7 Dec 2018 15:29:04 -0800 Subject: [PATCH 19/55] Generalize only after all other type unification (still broken) --- src/relay/pass/type_infer.cc | 68 +++++++++++++++++++++-------------- src/relay/pass/type_solver.cc | 6 ++-- 2 files changed, 44 insertions(+), 30 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 640c0c8c0414..80935d12dfa7 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -75,13 +75,20 @@ struct ResolvedTypeInfo { // Converts incomplete types remaining in function signature to type vars class Generalizer : public TypeMutator { public: - Generalizer() : subst_map_({}), varno_(0) {} + Generalizer() : subst_map_({}), vars_({}), varno_(0) {} // turns each distinct incomplete type into a type var and returns // the transformed type with an array of all type vars present - Type Generalize(const Type &t, Array* vars) { - vars_ = vars; - return VisitType(t); + Type Generalize(const Type &t) { + Type ret = VisitType(t); + + auto* ftn = ret.as(); + if (ftn == nullptr) { + return ret; + } + + // for a func type, we add the type vars to the list at top + return FuncTypeNode::make(ftn->arg_types, ftn->ret_type, vars_, ftn->type_constraints); } Type VisitType_(const IncompleteTypeNode *op) override { @@ -96,14 +103,26 @@ class Generalizer : public TypeMutator { ss << "_var_" << varno_; varno_++; TypeVar new_var = TypeVarNode::make(ss.str(), TypeVarNode::Kind::kType); - vars_->push_back(new_var); + vars_.push_back(new_var); subst_map_.Set(t, new_var); return new_var; } + Type VisitType_(const FuncTypeNode *op) override { + // drop type params, only do it at the top level + Array arg_types; + for (auto arg_type : op->arg_types) { + arg_types.push_back(this->VisitType(arg_type)); + } + + Type ret_type = this->VisitType(op->ret_type); + + return FuncTypeNode::make(arg_types, ret_type, {}, op->type_constraints); + } + private: tvm::Map subst_map_; - Array* vars_; + Array vars_; int varno_; }; @@ -374,9 +393,7 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const FunctionNode* f) final { - // first get blank f type - // unify with candidate - // generalize + solver_.Solve(); Array incomplete_arg_types; for (auto param : f->params) { incomplete_arg_types.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); @@ -398,26 +415,22 @@ class TypeInferencer : private ExprFunctor { rtype, f->type_params, {}); - solver_.Solve(); - auto unified = - GetRef(Unify(incompleteFuncType, - candidateFuncType, f->span) - .as()); + return Unify(incompleteFuncType, candidateFuncType, f->span); - // generalize remaining incomplete types - Generalizer gen; - Array arg_types; - Array type_params; - for (auto param : unified->type_params) { - Type gen_param = gen.Generalize(param, &type_params); - Type atype = Unify(param, gen_param, f->span); - arg_types.push_back(atype); - } + // // generalize remaining incomplete types + // Generalizer gen; + // Array arg_types; + // Array type_params; + // for (auto param : unified->type_params) { + // Type gen_param = gen.Generalize(param, &type_params); + // Type atype = Unify(param, gen_param, f->span); + // arg_types.push_back(atype); + // } - Type gen_ret = gen.Generalize(unified->ret_type, &type_params); - Type final_ret = Unify(gen_ret, unified->ret_type, f->span); + // Type gen_ret = gen.Generalize(unified->ret_type, &type_params); + // Type final_ret = Unify(gen_ret, unified->ret_type, f->span); - return FuncTypeNode::make(arg_types, final_ret, type_params, {}); + // return FuncTypeNode::make(arg_types, final_ret, type_params, {}); // for (auto param : f->params) { // GetType(param); @@ -509,6 +522,8 @@ class TypeInferencer::Resolver : public ExprMutator { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); Type checked_type = solver_->Resolve(it->second.checked_type); + Generalizer gen; + checked_type = gen.Generalize(checked_type); CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; @@ -605,6 +620,7 @@ Expr TypeInferencer::Infer(Expr expr) { GetType(expr); // Step 1: Solve the constraints. solver_.Solve(); + // Step 2: Attach resolved types to checked_type field. auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); CHECK(WellFormed(resolved_expr)); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index be97fe3e66b3..8e4fac2709e1 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -131,13 +131,11 @@ class TypeSolver::Unifier : // instantiate away all type parameters in a function type FuncType InstantiateFuncType(const FuncType& ft) { - Map subst_map; for (auto type_param : ft->type_params) { - Type hole = InstantiateTypeVar(type_param); - subst_map.Set(type_param, hole); + InstantiateTypeVar(type_param); } - Type transformed = Bind(ft, subst_map); + Type transformed = Bind(ft, tv_map_); auto* new_ft = transformed.as(); // drop the type param list altogether return FuncTypeNode::make(new_ft->arg_types, From 599d4a64f09d14a7cf96ffe0e09ff1b71dc41073 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 7 Dec 2018 16:56:23 -0800 Subject: [PATCH 20/55] Correct error in type var instantiation --- src/relay/pass/type_solver.cc | 11 +++++++++-- tests/python/relay/test_type_solver.py | 16 +++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 8e4fac2709e1..ce056a678b28 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -120,7 +120,7 @@ class TypeSolver::Unifier : TypeVar tv = GetRef(tvn); auto it = tv_map_.find(tv); - if (tv_map_.find(tv) != tv_map_.end()) { + if (it != tv_map_.end()) { return (*it).second; } @@ -135,7 +135,14 @@ class TypeSolver::Unifier : InstantiateTypeVar(type_param); } - Type transformed = Bind(ft, tv_map_); + // to avoid error when substituting type vars (TypeMutator + // errors out if the type var list in a FuncType contains an + // IncompleteType) + FuncType strip_tvs = FuncTypeNode::make(ft->arg_types, + ft->ret_type, {}, + ft->type_constraints); + Type transformed = Bind(strip_tvs, tv_map_); + auto* new_ft = transformed.as(); // drop the type param list altogether return FuncTypeNode::make(new_ft->arg_types, diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 1e545f8dd6aa..3ebe4bdfe425 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -116,6 +116,21 @@ def test_unify_vars_under_tuples(): assert (unified == tup1 or unified == tup2) +def test_instantiation_of_typevars(): + solver = make_solver() + + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + + a = relay.ty.TypeVar('a') + b = relay.ty.TypeVar('b') + + ft1 = relay.ty.FuncType([t1], t2) + ft2 = relay.ty.FuncType([a], b, [a, b]) + unified = solver.Unify(ft1, ft2) + assert (unified == solver.Resolve(ft1)) + + def test_recursive_backward_solving(): solver = make_solver() @@ -166,7 +181,6 @@ def test_backward_solving_after_child_update(): assert solver.Resolve(t4) == tup_concrete assert solver.Resolve(t5) == tup_concrete - @raises(tvm._ffi.base.TVMError) def test_incompatible_tuple_unification(): solver = make_solver() From 600d024e07911dbf57e3dd8530c3fed1fd953612 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 9 Dec 2018 18:33:27 -0800 Subject: [PATCH 21/55] Do not permit free type variables in tests --- tests/cpp/relay_pass_type_infer_test.cc | 14 +++++++++++--- tests/python/relay/test_type_infer.py | 12 ++++++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 385bde974014..3509450d98bd 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -9,10 +9,18 @@ TEST(Relay, SelfReference) { auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); auto x = relay::VarNode::make("x", type_a); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, Array{}); - auto fx = relay::CallNode::make(f, Array{ x }); + auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, + Array{type_a, type_b}); + + auto y = relay::VarNode::make("y", type_a); + auto call = relay::CallNode::make(f, Array{ y }); + auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, type_b, + Array{type_a, type_b}); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); - CHECK_EQ(type_fx->checked_type(), type_a); + + auto expected = relay::FuncTypeNode::make(tvm::Array{ type_a }, type_a, + relay::Array{type_a} , {}); + CHECK_EQ(type_fx->checked_type(), expected); } int main(int argc, char ** argv) { diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 6d1cd134a6eb..b159c94ddddb 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -218,13 +218,17 @@ def f(x) { a = relay.TypeVar("a") b = relay.TypeVar("b") x = relay.var("x", a) + y = relay.var("y", a) sb = relay.ScopeBuilder() f = relay.Function([x], x, b, [a, b]) - fx = relay.Call(f, [x]) - assert relay.ir_pass.infer_type(x).checked_type == a - assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) - assert relay.ir_pass.infer_type(fx).checked_type == a + fx = relay.Function([y], relay.Call(f, [y])) + + x_type = relay.ir_pass.infer_type(x).checked_type + f_type = relay.ir_pass.infer_type(f).checked_type + call_type = relay.ir_pass.infer_type(fx).checked_type + assert f_type == relay.FuncType([a], a, [a]) + assert call_type == relay.FuncType([a], a, [a]) def test_global_var_recursion(): From face1132d25615955bf8af875b71dec73a6db720 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 9 Dec 2018 18:34:15 -0800 Subject: [PATCH 22/55] Remove commented-out function inference code --- src/relay/pass/type_infer.cc | 50 ------------------------------------ 1 file changed, 50 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 80935d12dfa7..48faf8ac975a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -416,56 +416,6 @@ class TypeInferencer : private ExprFunctor { f->type_params, {}); return Unify(incompleteFuncType, candidateFuncType, f->span); - - // // generalize remaining incomplete types - // Generalizer gen; - // Array arg_types; - // Array type_params; - // for (auto param : unified->type_params) { - // Type gen_param = gen.Generalize(param, &type_params); - // Type atype = Unify(param, gen_param, f->span); - // arg_types.push_back(atype); - // } - - // Type gen_ret = gen.Generalize(unified->ret_type, &type_params); - // Type final_ret = Unify(gen_ret, unified->ret_type, f->span); - - // return FuncTypeNode::make(arg_types, final_ret, type_params, {}); - - // for (auto param : f->params) { - // GetType(param); - // } - // Type rtype = GetType(f->body); - // if (f->ret_type.defined()) { - // rtype = this->Unify(f->ret_type, rtype, f->span); - // } - - // // Run solver using the currently known information - // solver_.Solve(); - // // Trying to resolve - // Array arg_types; - // Array type_params = f->type_params; - // Generalizer gen; - - // for (size_t i = 0; i < f->params.size(); ++i) { - // Type atype = solver_.Resolve(GetType(f->params[i])); - // // CHECK(atype.as() == nullptr) - // // << "Cannot resolve type of " << i - // // << "-th parameter of function at" << f->span; - // Type gen_atype = gen.Generalize(atype, &type_params); - // atype = this->Unify(atype, gen_atype, f->span); - // arg_types.push_back(atype); - // } - - // rtype = solver_.Resolve(rtype); - // Type gen_rtype = gen.Generalize(rtype, &type_params); - // this->Unify(rtype, gen_rtype, f->span); - // rtype = gen_rtype; - - // // CHECK(rtype.as() == nullptr) - // // << "Cannot resolve return type of function at" << f->span; - // // do not support constraint lifting for now. - // return FuncTypeNode::make(arg_types, rtype, type_params, {}); } }; From aa9c837b3704922e564b3bd8f65cce622fb05cad Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 9 Dec 2018 18:36:26 -0800 Subject: [PATCH 23/55] Instantiate generalizer once in type inference --- src/relay/pass/type_infer.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 48faf8ac975a..ba07172bc5e7 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -423,7 +423,7 @@ class TypeInferencer::Resolver : public ExprMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { + : tmap_(tmap), solver_(solver), gen_(Generalizer()) { } Expr VisitExpr_(const VarNode* op) final { @@ -472,8 +472,7 @@ class TypeInferencer::Resolver : public ExprMutator { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); Type checked_type = solver_->Resolve(it->second.checked_type); - Generalizer gen; - checked_type = gen.Generalize(checked_type); + checked_type = gen_.Generalize(checked_type); CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; @@ -559,6 +558,7 @@ class TypeInferencer::Resolver : public ExprMutator { private: const std::unordered_map& tmap_; TypeSolver* solver_; + Generalizer gen_; // whether attach the checked type as type_annotation // if original type anntation is missing. bool update_missing_type_annotation_{true}; From 59eabc72a569f1ca234af1451db95a268db55d1f Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 9 Dec 2018 19:02:07 -0800 Subject: [PATCH 24/55] Fix error in type inference cpp test --- tests/cpp/relay_pass_type_infer_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 3509450d98bd..55323b3ce9e8 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -15,11 +15,11 @@ TEST(Relay, SelfReference) { auto y = relay::VarNode::make("y", type_a); auto call = relay::CallNode::make(f, Array{ y }); auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, type_b, - Array{type_a, type_b}); + tvm::Array{type_a, type_b}); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); auto expected = relay::FuncTypeNode::make(tvm::Array{ type_a }, type_a, - relay::Array{type_a} , {}); + tvm::Array{type_a} , {}); CHECK_EQ(type_fx->checked_type(), expected); } From 32240e07fedb8b980fd61468e4467967e9d0393e Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 9 Dec 2018 19:05:30 -0800 Subject: [PATCH 25/55] Use alpha equality in cpp test --- tests/cpp/relay_pass_type_infer_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 55323b3ce9e8..cc073f28f56f 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -20,7 +20,7 @@ TEST(Relay, SelfReference) { auto expected = relay::FuncTypeNode::make(tvm::Array{ type_a }, type_a, tvm::Array{type_a} , {}); - CHECK_EQ(type_fx->checked_type(), expected); + CHECK(AlphaEqual(type_fx->checked_type(), expected)); } int main(int argc, char ** argv) { From a7e6fa6513444e869d135a3de784c55a85b07e18 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 10 Dec 2018 17:36:43 -0800 Subject: [PATCH 26/55] Use free var pass for collecting type params in generalization. Fix double-counting bug in free var pass --- include/tvm/relay/pass.h | 11 +++++++++++ src/relay/pass/type_infer.cc | 9 ++++----- src/relay/pass/util.cc | 1 + 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1897809f48b8..f89064912339 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -130,6 +130,17 @@ tvm::Array FreeVars(const Expr& expr); */ tvm::Array FreeTypeVars(const Expr& expr); +/*! \brief Get free TypeVars from type t. + * + * Free type parameters are type parameters that are not bound by a function + * type in the context. + * + * \param expr the expression. + * + * \return List of free type vars, in the PostDFS order visited by type. + */ +tvm::Array FreeTypeVars(const Type& t); + /*! \brief Remove expressions which does not effect the program result. * * It will remove let bindings which are not referenced, and branches that will diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index ba07172bc5e7..d046a3c890f2 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -75,7 +75,7 @@ struct ResolvedTypeInfo { // Converts incomplete types remaining in function signature to type vars class Generalizer : public TypeMutator { public: - Generalizer() : subst_map_({}), vars_({}), varno_(0) {} + Generalizer() : subst_map_({}), varno_(0) {} // turns each distinct incomplete type into a type var and returns // the transformed type with an array of all type vars present @@ -87,8 +87,9 @@ class Generalizer : public TypeMutator { return ret; } - // for a func type, we add the type vars to the list at top - return FuncTypeNode::make(ftn->arg_types, ftn->ret_type, vars_, ftn->type_constraints); + // for a func type, we generalize at the top level + Array free_vars = FreeTypeVars(GetRef(ftn)); + return FuncTypeNode::make(ftn->arg_types, ftn->ret_type, free_vars, ftn->type_constraints); } Type VisitType_(const IncompleteTypeNode *op) override { @@ -103,7 +104,6 @@ class Generalizer : public TypeMutator { ss << "_var_" << varno_; varno_++; TypeVar new_var = TypeVarNode::make(ss.str(), TypeVarNode::Kind::kType); - vars_.push_back(new_var); subst_map_.Set(t, new_var); return new_var; } @@ -122,7 +122,6 @@ class Generalizer : public TypeMutator { private: tvm::Map subst_map_; - Array vars_; int varno_; }; diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index b99d975135be..2f8c578de078 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -24,6 +24,7 @@ class FreeTypeVarTVisitor : public TypeVisitor { TypeVar var = GetRef(tp); if (bound_vars_->count(var) == 0) { free_vars_->push_back(var); + bound_vars_->insert(var); } } From 299574df03578130f43a77c3071aa4fa1982409f Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 10 Dec 2018 19:00:14 -0800 Subject: [PATCH 27/55] Handle recursion in let, generalize early in that case --- src/relay/pass/type_infer.cc | 14 +++++++- tests/python/relay/test_type_infer.py | 50 +++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index d046a3c890f2..9861aef2f725 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -149,6 +149,8 @@ class TypeInferencer : private ExprFunctor { class Resolver; // internal environment Module mod_; + // Generalizer for handling let nodes + Generalizer gen_; // map from expression to checked type // type inferencer will populate it up std::unordered_map type_map_; @@ -234,11 +236,21 @@ class TypeInferencer : private ExprFunctor { } Type VisitExpr_(const LetNode* op) final { + // if the definition is a function literal, permit recursion + bool isFunctionLiteral = op->value.as() != nullptr; + if (isFunctionLiteral) { + type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + } + Type vtype = GetType(op->value); + // need to generalize inner functions immediately (per H-M) + if (isFunctionLiteral) { + vtype = gen_.Generalize(vtype); + } if (op->var->type_annotation.defined()) { vtype = Unify(vtype, op->var->type_annotation, op->span); } - CHECK(!type_map_.count(op->var)); + CHECK(isFunctionLiteral || !type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program type_map_[op->var].checked_type = vtype; return GetType(op->body); diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index b159c94ddddb..ef7a5249baf2 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -231,6 +231,56 @@ def f(x) { assert call_type == relay.FuncType([a], a, [a]) +def test_nested_recursive_function(): + """ + Program: + def f(x) { + let g = fun(x) { g(x) }; + g(x) + } + """ + x = relay.var("x") + y = relay.var("y") + g = relay.var("g") + f = relay.Function([x], + relay.Let(g, + relay.Function( + [y], relay.Call(g, [y])), + relay.Call(g, [x]))) + + a = relay.TypeVar("a") + b = relay.TypeVar("b") + f_type = relay.ir_pass.infer_type(f).checked_type + assert f_type == relay.FuncType([a], b, [a, b]) + + +def test_proper_inner_function_generalization(): + """ + Program: + def f() { + let id = fun(x) { x }; + let unit = id(()); + let idid = id(id); + unit + } + """ + x = relay.var("x") + unit = relay.var("unit") + id1 = relay.var("id") + id2 = relay.var("idid") + f = relay.Function( + [], + relay.Let(id1, relay.Function([x], x), + relay.Let( + unit, relay.Call(id1, [relay.Tuple([])]), + relay.Let( + id2, relay.Call(id1, [id1]), + unit)))) + + f_type = relay.ir_pass.infer_type(f).checked_type + assert f_type == relay.FuncType([], relay.TupleType([])) + + def test_global_var_recursion(): mod = relay.Module({}) gv = relay.GlobalVar("foo") From 0dcd205e8c9aefea5837d8717145a7285564b9a3 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 10 Dec 2018 19:14:17 -0800 Subject: [PATCH 28/55] Fix FreeVar pass doc --- include/tvm/relay/pass.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index f89064912339..28a123bdc1e6 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -135,7 +135,7 @@ tvm::Array FreeTypeVars(const Expr& expr); * Free type parameters are type parameters that are not bound by a function * type in the context. * - * \param expr the expression. + * \param t the type. * * \return List of free type vars, in the PostDFS order visited by type. */ From ebf968837bbc246e00c0651e721ec366788dc5ff Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 12 Dec 2018 14:28:44 -0800 Subject: [PATCH 29/55] Refactoring of free var visitors to ensure fixed order (mostly @jroesch's work) --- include/tvm/relay/pass.h | 57 +++++++++++ src/relay/pass/util.cc | 213 ++++++++++++++++++++++++++++++++------- 2 files changed, 231 insertions(+), 39 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 28a123bdc1e6..0d3697a5b507 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -108,6 +108,17 @@ bool AlphaEqual(const Type& t1, const Type& t2); */ bool WellFormed(const Expr& expr); +/*! \brief Get all bound variables from expression expr. + * + * Bound variables are all variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound vars, in the PostDFS order in the expression. + */ +tvm::Array BoundVars(const Expr& expr); + /*! \brief Get free type parameters from expression expr. * * Free variables are variables that are not bound by a @@ -119,6 +130,14 @@ bool WellFormed(const Expr& expr); */ tvm::Array FreeVars(const Expr& expr); +/*! \brief Get all variables from expression expr. + * + * \param expr the expression. + * + * \return List of free vars, in the PostDFS order in the expression. + */ +tvm::Array AllVars(const Expr& expr); + /*! \brief Get free TypeVars from expression expr. * * Free type parameters are type parameters that are not bound by a function @@ -141,6 +160,44 @@ tvm::Array FreeTypeVars(const Expr& expr); */ tvm::Array FreeTypeVars(const Type& t); +/*! \brief Get all bound type variables from expression expr. + * + * Bound variables are all type variables that are declared in the expr. + * They only have meaning inside that expr, and can only be used in it. + * + * \param expr the expression. + * + * \return List of bound type vars, in the PostDFS order in the expression. + */ +tvm::Array BoundTypeVars(const Expr& expr); + +/*! \brief Get all bound type variables from type t. + * + * Bound variables are all type variables that are declared in the type. + * They only have meaning inside that type, and can only be used in it. + * + * \param t the type + * + * \return List of bound type vars, in the PostDFS order visited by type. + */ +tvm::Array BoundTypeVars(const Type& t); + +/*! \brief Get all type variables in expression expr. + * + * \param expr the expression. + * + * \return List of type vars, in the PostDFS order in the expression. + */ +tvm::Array AllTypeVars(const Expr& expr); + +/*! \brief Get all type variables in type t. + * + * \param t the type. + * + * \return List of type vars, in the PostDFS order visited by type. + */ +tvm::Array AllTypeVars(const Type& t); + /*! \brief Remove expressions which does not effect the program result. * * It will remove let bindings which are not referenced, and branches that will diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 2f8c578de078..7530675ecce2 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -12,106 +12,211 @@ namespace tvm { namespace relay { -// FreeTypeVar -class FreeTypeVarTVisitor : public TypeVisitor { +template +struct InsertionSet { + std::unordered_set set; + std::vector data; + void Insert(const T& t) { + if (set.count(t) == 0) { + set.insert(t); + data.push_back(t); + } + } +}; + +class TypeVarTVisitor : public TypeVisitor { public: - FreeTypeVarTVisitor( - Array* free_vars, - std::unordered_set* bound_vars) - : free_vars_(free_vars), bound_vars_(bound_vars) { } + TypeVarTVisitor( + InsertionSet* type_vars, + InsertionSet* bound_type_vars) + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); - if (bound_vars_->count(var) == 0) { - free_vars_->push_back(var); - bound_vars_->insert(var); - } + type_vars_->Insert(var); } void VisitType_(const FuncTypeNode* f) final { for (auto type_param : f->type_params) { - bound_vars_->insert(type_param); + type_vars_->Insert(type_param); + bound_type_vars_->Insert(type_param); } TypeVisitor::VisitType_(f); } private: - Array* free_vars_; - std::unordered_set* bound_vars_; + InsertionSet* type_vars_; + InsertionSet* bound_type_vars_; }; -class FreeTypeVarEVisitor : private ExprVisitor { +class TypeVarEVisitor : private ExprVisitor { public: - Array Find(const Expr& expr) { - this->VisitExpr(expr); - return free_vars_; + Array CollectFree() { + Array ret; + for (const auto& v : type_vars_.data) { + if (bound_type_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; + } + + Array CollectBound() { + Array ret; + for (const auto& v : bound_type_vars_.data) { + ret.push_back(v); + } + return ret; + } + + Array CollectAll() { + Array ret; + for (const auto& v : bound_type_vars_.data) { + ret.push_back(v); + } + return ret; } - Array Find(const Type& type) { - this->VisitType(type); - return free_vars_; + Array Free(const Expr& expr) { + VisitExpr(expr); + return CollectFree(); + } + + Array Free(const Type& type) { + VisitType(type); + return CollectFree(); + } + + Array Bound(const Expr& expr) { + VisitExpr(expr); + return CollectBound(); + } + + Array Bound(const Type& type) { + VisitType(type); + return CollectBound(); + } + + Array All(const Expr& expr) { + VisitExpr(expr); + return CollectAll(); + } + + Array All(const Type& type) { + VisitType(type); + return CollectAll(); } void VisitExpr_(const FunctionNode* f) final { for (const auto& tp : f->type_params) { - bound_vars_.insert(tp); + type_vars_.Insert(tp); + bound_type_vars_.Insert(tp); } ExprVisitor::VisitExpr_(f); } void VisitType(const Type& t) final { - FreeTypeVarTVisitor(&free_vars_, &bound_vars_) + TypeVarTVisitor(&type_vars_, &bound_type_vars_) .VisitType(t); } private: - // The result list - Array free_vars_; - std::unordered_set bound_vars_; + InsertionSet type_vars_; + InsertionSet bound_type_vars_; }; -class FreeVarVisitor : protected ExprVisitor { +class VarVisitor : protected ExprVisitor { public: - Array Find(const Expr& expr) { + Array Free(const Expr& expr) { this->VisitExpr(expr); - return free_vars_; + Array ret; + for (const auto& v : vars_.data) { + if (bound_vars_.set.count(v) == 0) { + ret.push_back(v); + } + } + return ret; } - void VisitExpr_(const VarNode* var) final { - if (bound_vars_.count(var) == 0) { - free_vars_.push_back(GetRef(var)); + Array Bound(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : bound_vars_.data) { + ret.push_back(v); } + return ret; + } + + Array All(const Expr& expr) { + this->VisitExpr(expr); + Array ret; + for (const auto& v : vars_.data) { + ret.push_back(v); + } + return ret; + } + + void Bounded(const Var& v) { + bound_vars_.Insert(v); + vars_.Insert(v); + } + + void VisitExpr_(const VarNode* var) final { + vars_.Insert(GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { - bound_vars_.insert(param.operator->()); + Bounded(param); } VisitExpr(op->body); } void VisitExpr_(const LetNode* op) final { - bound_vars_.insert(op->var.operator->()); + Bounded(op->var); VisitExpr(op->value); VisitExpr(op->body); } private: - // The result list - Array free_vars_; - std::unordered_set bound_vars_; + InsertionSet vars_; + InsertionSet bound_vars_; }; tvm::Array FreeTypeVars(const Expr& expr) { - return FreeTypeVarEVisitor().Find(expr); + return TypeVarEVisitor().Free(expr); } tvm::Array FreeTypeVars(const Type& type) { - return FreeTypeVarEVisitor().Find(type); + return TypeVarEVisitor().Free(type); +} + +tvm::Array BoundTypeVars(const Expr& expr) { + return TypeVarEVisitor().Bound(expr); +} + +tvm::Array BoundTypeVars(const Type& type) { + return TypeVarEVisitor().Bound(type); +} + +tvm::Array AllTypeVars(const Expr& expr) { + return TypeVarEVisitor().All(expr); +} + +tvm::Array AllTypeVars(const Type& type) { + return TypeVarEVisitor().All(type); } tvm::Array FreeVars(const Expr& expr) { - return FreeVarVisitor().Find(expr); + return VarVisitor().Free(expr); +} + +tvm::Array BoundVars(const Expr& expr) { + return VarVisitor().Bound(expr); +} + +tvm::Array AllVars(const Expr& expr) { + return VarVisitor().All(expr); } TVM_REGISTER_API("relay._ir_pass.free_vars") @@ -119,6 +224,16 @@ TVM_REGISTER_API("relay._ir_pass.free_vars") *ret = FreeVars(args[0]); }); +TVM_REGISTER_API("relay._ir_pass.bound_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = BoundVars(args[0]); + }); + +TVM_REGISTER_API("relay._ir_pass.all_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = AllVars(args[0]); + }); + TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; @@ -129,6 +244,26 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars") } }); +TVM_REGISTER_API("relay._ir_pass.bound_type_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[0]; + if (x.as()) { + *ret = BoundTypeVars(Downcast(x)); + } else { + *ret = BoundTypeVars(Downcast(x)); + } + }); + +TVM_REGISTER_API("relay._ir_pass.all_type_vars") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NodeRef x = args[0]; + if (x.as()) { + *ret = AllTypeVars(Downcast(x)); + } else { + *ret = AllTypeVars(Downcast(x)); + } + }); + /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. From 6b15cf793673d8187ed69f1cac9fdb4c50d1ea15 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 12 Dec 2018 15:50:38 -0800 Subject: [PATCH 30/55] Add tests for various variable collection passes, add Python handles for passes, fix bug in all vars pass --- python/tvm/relay/ir_pass.py | 68 +++++++++- src/relay/pass/util.cc | 8 +- tests/python/relay/test_pass_free_vars.py | 41 ------ tests/python/relay/test_pass_vars.py | 144 ++++++++++++++++++++++ 4 files changed, 214 insertions(+), 47 deletions(-) delete mode 100644 tests/python/relay/test_pass_free_vars.py create mode 100644 tests/python/relay/test_pass_vars.py diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 1bec7ccd72d5..d5d5e9261fc7 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -158,6 +158,38 @@ def free_vars(expr): return _ir_pass.free_vars(expr) +def bound_vars(expr): + """Get bound vars from expression expr in post-DFS order. + + Parameters + ---------- + expr: tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of bound variables in post-DFS order. + """ + return _ir_pass.bound_vars(expr) + + +def all_vars(expr): + """Get all vars from expression expr in post-DFS order. + + Parameters + ---------- + expr: tvm.relay.Expr + The input expression + + Returns + ------- + free : List[tvm.relay.Var] + The list of all variables in post-DFS order. + """ + return _ir_pass.all_vars(expr) + + def free_type_vars(expr): """Get free type variables from expression/type e @@ -168,12 +200,44 @@ def free_type_vars(expr): Returns ------- - free : List[tvm.relay.TypeParam] - The list of free type variables + free : List[tvm.relay.TypeVar] + The list of free type variables in post-DFS order """ return _ir_pass.free_type_vars(expr) +def bound_type_vars(expr): + """Get bound type variables from expression/type e + + Parameters + ---------- + expr: Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of bound type variables in post-DFS order + """ + return _ir_pass.bound_type_vars(expr) + + +def all_type_vars(expr): + """Get all type variables from expression/type e + + Parameters + ---------- + expr: Union[tvm.relay.Expr,tvm.relay.Type] + The input expression/type + + Returns + ------- + free : List[tvm.relay.TypeVar] + The list of all type variables in post-DFS order + """ + return _ir_pass.all_type_vars(expr) + + def simplify_inference(expr): """ Simplify the data-flow graph for inference phase. diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 7530675ecce2..03b72831e239 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -71,7 +71,7 @@ class TypeVarEVisitor : private ExprVisitor { Array CollectAll() { Array ret; - for (const auto& v : bound_type_vars_.data) { + for (const auto& v : type_vars_.data) { ret.push_back(v); } return ret; @@ -237,7 +237,7 @@ TVM_REGISTER_API("relay._ir_pass.all_vars") TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; - if (x.as()) { + if (x.as_derived()) { *ret = FreeTypeVars(Downcast(x)); } else { *ret = FreeTypeVars(Downcast(x)); @@ -247,7 +247,7 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars") TVM_REGISTER_API("relay._ir_pass.bound_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; - if (x.as()) { + if (x.as_derived()) { *ret = BoundTypeVars(Downcast(x)); } else { *ret = BoundTypeVars(Downcast(x)); @@ -257,7 +257,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_type_vars") TVM_REGISTER_API("relay._ir_pass.all_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef x = args[0]; - if (x.as()) { + if (x.as_derived()) { *ret = AllTypeVars(Downcast(x)); } else { *ret = AllTypeVars(Downcast(x)); diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py deleted file mode 100644 index 151dbe1412bc..000000000000 --- a/tests/python/relay/test_pass_free_vars.py +++ /dev/null @@ -1,41 +0,0 @@ -import tvm -from tvm import relay -from tvm.relay.ir_pass import free_vars, free_type_vars - -def test_free_vars(): - ty = relay.TensorType([], "int32") - x = relay.Var("x", ty) - fvx = free_vars(x) - assert len(fvx) == 1 - assert fvx[0] == x - v = relay.Constant(tvm.nd.array(10)) - - let = relay.Let(x, v, x) - fvx = free_vars(let) - assert len(free_vars(let)) == 0 - f = relay.Function([x], x, ty) - assert len(free_vars(f)) == 0 - - -def test_tuple(): - t = relay.Var('t') - fv = free_vars(relay.Tuple([t, t])) - assert len(fv) == 1 - assert fv[0] == t - fv = free_vars(relay.TupleGetItem(t, 123)) - assert len(fv) == 1 - assert fv[0] == t - - -def test_free_type_vars(): - tp = relay.TypeVar("") - ty = relay.TupleType([tp, relay.TensorType([], "int32")]) - x = relay.Var("x", ty) - y = relay.Var("y") - let = relay.Let(x, y, x) - fvl = free_vars(let) - assert len(fvl) == 1 - assert fvl[0] == y - ftvl = free_type_vars(let) - assert len(ftvl) == 1 - assert ftvl[0] == tp diff --git a/tests/python/relay/test_pass_vars.py b/tests/python/relay/test_pass_vars.py new file mode 100644 index 000000000000..c8d3d6d14992 --- /dev/null +++ b/tests/python/relay/test_pass_vars.py @@ -0,0 +1,144 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import (free_vars, free_type_vars, + bound_vars, bound_type_vars, + all_vars, all_type_vars) + +def assert_vars_match(actual, expected): + assert len(actual) == len(expected) + for i in range(len(actual)): + assert actual[i] == expected[i] + + +def test_free_vars(): + ty = relay.TensorType([], "int32") + x = relay.Var("x", ty) + fvx = free_vars(x) + assert len(fvx) == 1 + assert fvx[0] == x + v = relay.Constant(tvm.nd.array(10)) + + let = relay.Let(x, v, x) + fvx = free_vars(let) + assert len(free_vars(let)) == 0 + f = relay.Function([x], x, ty) + assert len(free_vars(f)) == 0 + + +def test_free_vars_tuple(): + t = relay.Var('t') + fv = free_vars(relay.Tuple([t, t])) + assert len(fv) == 1 + assert fv[0] == t + fv = free_vars(relay.TupleGetItem(t, 123)) + assert len(fv) == 1 + assert fv[0] == t + + +def test_free_type_vars(): + tp = relay.TypeVar("") + ty = relay.TupleType([tp, relay.TensorType([], "int32")]) + x = relay.Var("x", ty) + y = relay.Var("y") + let = relay.Let(x, y, x) + fvl = free_vars(let) + assert len(fvl) == 1 + assert fvl[0] == y + ftvl = free_type_vars(let) + assert len(ftvl) == 1 + assert ftvl[0] == tp + + +def test_bound_vars(): + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + a = relay.Var("a") + + f1 = relay.Function([x, y, z], relay.Let(a, x, relay.Tuple([]))) + assert_vars_match(bound_vars(f1), [x, y, z, a]) + + tup = relay.Tuple([x, y, z, a]) + assert len(bound_vars(tup)) == 0 + + f2 = relay.Function([x, y], relay.Tuple([x, y, z, a])) + assert_vars_match(bound_vars(f2), [x, y]) + + +def test_bound_type_vars(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + + ft1 = relay.FuncType([a], b, [a, b]) + bound_ft1 = bound_type_vars(ft1) + assert_vars_match(bound_type_vars(ft1), [a, b]) + + ft2 = relay.FuncType([], c, [a]) + assert_vars_match(bound_type_vars(ft2), [a]) + + tup_ty = relay.TupleType([a, b, c]) + assert len(bound_type_vars(tup_ty)) == 0 + + f1 = relay.Function([], relay.Tuple([]), type_params=[a, b]) + assert_vars_match(bound_type_vars(f1), [a, b]) + + f2 = relay.Function([], relay.Tuple([]), c) + assert len(bound_type_vars(f2)) == 0 + + x = relay.Var("x", a) + let1 = relay.Let(x, relay.Tuple([]), x) + assert len(bound_type_vars(let1)) == 0 + + let2 = relay.Let(x, relay.Function([], relay.Tuple([]), type_params=[b, c]), x) + assert_vars_match(bound_type_vars(let2), [b, c]) + + +def test_all_vars(): + x = relay.Var("x") + y = relay.Var("y") + z = relay.Var("z") + + f1 = relay.Function([x, y], z) + assert_vars_match(all_vars(f1), [x, y, z]) + + f2 = relay.Function([x], relay.Let(y, relay.Tuple([]), z)) + assert_vars_match(all_vars(f2), [x, y, z]) + + f3 = relay.Function([x], relay.Tuple([y, z])) + assert_vars_match(all_vars(f3), [x, y, z]) + + tup = relay.Tuple([x, y, z]) + assert_vars_match(all_vars(tup), [x, y, z]) + + +def test_all_type_vars(): + a = relay.TypeVar("a") + b = relay.TypeVar("b") + c = relay.TypeVar("c") + + ft1 = relay.FuncType([b], c, [a]) + assert_vars_match(all_type_vars(ft1), [a, b, c]) + + ft2 = relay.FuncType([], relay.TupleType([a, b, c]), []) + assert_vars_match(all_type_vars(ft2), [a, b, c]) + + w = relay.Var("w") + x = relay.Var("x", a) + y = relay.Var("y", b) + z = relay.Var("z", c) + + f1 = relay.Function([x], y, b, [a]) + assert_vars_match(all_type_vars(f1), [a, b]) + + f2 = relay.Function([x], relay.Let(y, x, z)) + assert_vars_match(all_type_vars(f2), [a, b, c]) + + f3 = relay.Function([], relay.Tuple([x, y, z]), ret_type=relay.TupleType([a, b, c])) + assert_vars_match(all_type_vars(f3), [a, b, c]) + + f4 = relay.Function([w], relay.Tuple([]), type_params=[a, b, c]) + assert_vars_match(all_type_vars(f4), [a, b, c]) + + f5 = relay.Function([w], w) + assert len(all_type_vars(f5)) == 0 From 6ae595a3dc29391cf0c0a6a339ac7d5c8a083522 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 17 Dec 2018 13:44:54 -0800 Subject: [PATCH 31/55] Unifier should not keep additional state, ensure that type params *match* in func types --- include/tvm/relay/type.h | 34 ---------- src/relay/pass/type_solver.cc | 89 +++++++++----------------- src/relay/pass/type_solver.h | 2 - tests/python/relay/test_type_solver.py | 33 ++++++++-- 4 files changed, 57 insertions(+), 101 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index ad62d1d78571..69a8a4fb0bd7 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -317,40 +317,6 @@ class TypeReporter : public NodeRef { using ContainerType = TypeReporterNode; }; -class TypeUnifier; -/*! - * \brief Data structure that handles type unification. - */ -class TypeUnifierNode : public Node { - public: - /*! - * \brief Returns a unified type based on the two arguments. - */ - TVM_DLL virtual Type Unify(const Type& dst, const Type& src) = 0; - - // unifier is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) final {} - - static constexpr const char* _type_key = "relay.TypeUnifier"; - TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); -}; - -/*! - * \brief Container class of TypeUnifier. - * \sa TypeUnifierNode - */ -class TypeUnifier : public NodeRef { - public: - TypeUnifier() {} - explicit TypeUnifier(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { - } - TypeUnifierNode* operator->() const { - return static_cast(node_.get()); - } - using ContainerType = TypeUnifierNode; -}; - - /*! * \brief User defined type constraint function. * diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index ce056a678b28..db79c30f49de 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -61,18 +61,15 @@ class TypeSolver::OccursChecker : public TypeVisitor { bool found_; }; -class TypeSolver::Unifier : - public TypeUnifierNode, public TypeFunctor { +class TypeSolver::Unifier : public TypeFunctor { public: - explicit Unifier(TypeSolver* solver) : solver_(solver), tv_map_({}) {} + explicit Unifier(TypeSolver* solver) : solver_(solver) {} - Type Unify(const Type& src, const Type& dst) final { + Type Unify(const Type& src, const Type& dst) { // Known limitation // - handle shape pattern matching - Type new_src = InstantiateTypeVar(src); - Type new_dst = InstantiateTypeVar(dst); - TypeNode* lhs = solver_->GetTypeNode(new_dst); - TypeNode* rhs = solver_->GetTypeNode(new_src); + TypeNode* lhs = solver_->GetTypeNode(dst); + TypeNode* rhs = solver_->GetTypeNode(src); // do occur check so we don't create self-referencing structure if (lhs->FindRoot() == rhs->FindRoot()) { @@ -102,7 +99,7 @@ class TypeSolver::Unifier : } } - // Checks whether lhs (taken to be a type var) occurs in t, meaning + // Checks whether lhs (taken to be a type hole) occurs in t, meaning // there is a recursive equality constraint, which should be rejected. // N.b.: A tautology like ?a = ?a is okay and should be checked for // *before* calling this method @@ -111,45 +108,6 @@ class TypeSolver::Unifier : return rc.Check(t); } - // if t is a type var, replace with an incomplete type - Type InstantiateTypeVar(const Type& t) { - auto* tvn = t.as(); - if (tvn == nullptr) { - return t; - } - - TypeVar tv = GetRef(tvn); - auto it = tv_map_.find(tv); - if (it != tv_map_.end()) { - return (*it).second; - } - - IncompleteType hole = IncompleteTypeNode::make(tv->kind); - tv_map_.Set(tv, hole); - return hole; - } - - // instantiate away all type parameters in a function type - FuncType InstantiateFuncType(const FuncType& ft) { - for (auto type_param : ft->type_params) { - InstantiateTypeVar(type_param); - } - - // to avoid error when substituting type vars (TypeMutator - // errors out if the type var list in a FuncType contains an - // IncompleteType) - FuncType strip_tvs = FuncTypeNode::make(ft->arg_types, - ft->ret_type, {}, - ft->type_constraints); - Type transformed = Bind(strip_tvs, tv_map_); - - auto* new_ft = transformed.as(); - // drop the type param list altogether - return FuncTypeNode::make(new_ft->arg_types, - new_ft->ret_type, {}, - new_ft->type_constraints); - } - // default: unify only if alpha-equal Type VisitTypeDefault_(const Node* op, const Type& tn) override { NodeRef nr = GetRef(op); @@ -181,37 +139,51 @@ class TypeSolver::Unifier : const auto* ftn = tn.as(); if (!ftn || op->arg_types.size() != ftn->arg_types.size() + || op->type_params.size() != ftn->type_params.size() || op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } - FuncType ft1 = InstantiateFuncType(GetRef(op)); - FuncType ft2 = InstantiateFuncType(GetRef(ftn)); + FuncType ft1 = GetRef(op); + FuncType ft2 = GetRef(ftn); + + // saves work if we can avoid remapping, etc. + if (AlphaEqual(ft1, ft2)) { + return ft1; + } + + // remap bound type params so we can compare by equality + Map subst_map; + for (size_t i = 0; i < ft1->type_params.size(); i++) { + subst_map.Set(ft2->type_params[i], ft1->type_params[i]); + } + + auto bound_ft2 = GetRef(Bind(ft2, subst_map).as()); - Type ret_type = Unify(ft1->ret_type, ft2->ret_type); + // now unify each field of the func types + Type ret_type = Unify(ft1->ret_type, bound_ft2->ret_type); std::vector arg_types; for (size_t i = 0; i < ft1->arg_types.size(); i++) { - Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); + Type arg_type = Unify(ft1->arg_types[i], bound_ft2->arg_types[i]); arg_types.push_back(arg_type); } std::vector type_constraints; for (size_t i = 0; i < ft1->type_constraints.size(); i++) { Type unified_constraint = Unify(ft1->type_constraints[i], - ft2->type_constraints[i]); + bound_ft2->type_constraints[i]); const auto* tcn = unified_constraint.as(); CHECK(tcn) << "Two type constraints unified into a non-constraint?" - << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; + << ft1->type_constraints[i] << " and " << bound_ft2->type_constraints[i]; type_constraints.push_back(GetRef(tcn)); } - return FuncTypeNode::make(arg_types, ret_type, {}, type_constraints); + return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); } private: TypeSolver* solver_; - Map tv_map_; }; class TypeSolver::Resolver : public TypeMutator { @@ -293,7 +265,7 @@ class TypeSolver::Propagator : public TypeFunctor { // constructor TypeSolver::TypeSolver() - : reporter_(make_node(this)), unifier_(make_node(this)) { + : reporter_(make_node(this)) { } // destructor @@ -309,7 +281,8 @@ TypeSolver::~TypeSolver() { // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - return unifier_->Unify(dst, src); + Unifier unifier(this); + return unifier.Unify(dst, src); } // Add type constraint to the solver. diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index c6d5c20d737d..e756440452d5 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -136,8 +136,6 @@ class TypeSolver { common::Arena arena_; /*! \brief Reporter that reports back to self */ TypeReporter reporter_; - /*! \brief Data structure for unifying types */ - TypeUnifier unifier_; /*! * \brief GetTypeNode that is corresponds to t. * if it do not exist, create a new one. diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 3ebe4bdfe425..12546e6b25b1 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -116,17 +116,20 @@ def test_unify_vars_under_tuples(): assert (unified == tup1 or unified == tup2) -def test_instantiation_of_typevars(): +def test_unify_functions_with_typevars(): solver = make_solver() - t1 = relay.ty.IncompleteType() - t2 = relay.ty.IncompleteType() - + x = relay.ty.IncompleteType() + a = relay.ty.TypeVar('a') b = relay.ty.TypeVar('b') - - ft1 = relay.ty.FuncType([t1], t2) - ft2 = relay.ty.FuncType([a], b, [a, b]) + c = relay.ty.TypeVar('c') + d = relay.ty.TypeVar('d') + e = relay.ty.TypeVar('e') + f = relay.ty.TypeVar('f') + + ft1 = relay.ty.FuncType([a, b], relay.TupleType([b, c]), [a, b, c]) + ft2 = relay.ty.FuncType([d, e], relay.TupleType([e, x]), [d, e, f]) unified = solver.Unify(ft1, ft2) assert (unified == solver.Resolve(ft1)) @@ -181,6 +184,22 @@ def test_backward_solving_after_child_update(): assert solver.Resolve(t4) == tup_concrete assert solver.Resolve(t5) == tup_concrete + +@raises(tvm._ffi.base.TVMError) +def test_unbound_type_var(): + solver = make_solver() + + # should not be able to unify because nothing is known about b and d + a = relay.ty.TypeVar('a') + b = relay.ty.TypeVar('b') + c = relay.ty.TypeVar('c') + d = relay.ty.TypeVar('d') + + ft1 = relay.ty.FuncType([a], b, [a]) + ft2 = relay.ty.FuncType([c], d, [c]) + solver.Unify(ft1, ft2) + + @raises(tvm._ffi.base.TVMError) def test_incompatible_tuple_unification(): solver = make_solver() From 13ecb5305f0a1ad93bcef3ae38a385be2dbc64b0 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 20 Dec 2018 13:36:59 -0800 Subject: [PATCH 32/55] Attempt at keeping type env in inferencer (broken) --- src/relay/pass/type_infer.cc | 54 ++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 9861aef2f725..b19eb7be8907 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -124,7 +124,7 @@ class Generalizer : public TypeMutator { tvm::Map subst_map_; int varno_; }; - + // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -155,16 +155,20 @@ class TypeInferencer : private ExprFunctor { // type inferencer will populate it up std::unordered_map type_map_; + // map from type vars to instantiations + tvm::Map type_env_; + // The solver used by the inferencer. TypeSolver solver_; // relation function TypeRelationFn tuple_getitem_rel_; - TypeRelationFn make_tuple_rel_; // Unify two types Type Unify(const Type& t1, const Type& t2, const Span& span) { // TODO(tqchen, jroesch): propagate span to solver try { - return solver_.Unify(t1, t2); + Type inst_t1 = Instantiate(t1); + Type inst_t2 = Instantiate(t2); + return solver_.Unify(inst_t1, inst_t2); } catch (const dmlc::Error &e) { LOG(FATAL) << "Error unifying `" @@ -182,7 +186,7 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = this->VisitExpr(expr); + Type ret = Instantiate(this->VisitExpr(expr)); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; @@ -289,6 +293,25 @@ class TypeInferencer : private ExprFunctor { return rtype; } + // instantiates the type in the current type env + Type Instantiate(const Type& t) { + if (!t.defined()) { + return t; + } + auto* ft = t.as(); + if (ft == nullptr) { + return Bind(t, type_env_); + } + + // strip type params before binding, then restore them + auto strip_tvs = FuncTypeNode::make(ft->arg_types, ft->ret_type, + {}, ft->type_constraints); + auto* bound = Bind(strip_tvs, type_env_).as(); + CHECK(bound != nullptr); + return FuncTypeNode::make(bound->arg_types, bound->ret_type, + ft->type_params, bound->type_constraints); + } + // instantiate the function type with fresh FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args, const Span& span) { tvm::Map subst_map; @@ -347,7 +370,12 @@ class TypeInferencer : private ExprFunctor { // with an unknown return type if (inc_ty_node != nullptr) { Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); + Array type_params; + for (auto type_arg : call->type_args) { + TypeVar tv = TypeVarNode::make("placeholder", TypeVarNode::Kind::kType); + type_params.push_back(tv); + } + Type func_type = FuncTypeNode::make(arg_types, ret_type, type_params, {}); Type unified = this->Unify(ftype, func_type, call->span); fn_ty_node = unified.as(); } @@ -405,15 +433,13 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const FunctionNode* f) final { solver_.Solve(); - Array incomplete_arg_types; - for (auto param : f->params) { - incomplete_arg_types.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); - } - FuncType incompleteFuncType = - FuncTypeNode::make(incomplete_arg_types, - IncompleteTypeNode::make(TypeVarNode::Kind::kType), - {}, {}); + // instantiate all type vars first, then assemble rest of type + for (auto type_param : f->type_params) { + auto fresh = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + type_env_.Set(type_param, fresh); + } + Array candidate_arg_types; for (auto param : f->params) { candidate_arg_types.push_back(GetType(param)); @@ -426,7 +452,7 @@ class TypeInferencer : private ExprFunctor { rtype, f->type_params, {}); - return Unify(incompleteFuncType, candidateFuncType, f->span); + return candidateFuncType; } }; From 6cb0ba21d63ccd3e74c7f8809cfa94356967e648 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 20 Dec 2018 17:21:19 -0800 Subject: [PATCH 33/55] Revert "Attempt at keeping type env in inferencer (broken)" This reverts commit fd3f56a9d01a613a0eaed3334cf223dd849849cc. --- src/relay/pass/type_infer.cc | 54 ++++++++++-------------------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b19eb7be8907..9861aef2f725 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -124,7 +124,7 @@ class Generalizer : public TypeMutator { tvm::Map subst_map_; int varno_; }; - + // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -155,20 +155,16 @@ class TypeInferencer : private ExprFunctor { // type inferencer will populate it up std::unordered_map type_map_; - // map from type vars to instantiations - tvm::Map type_env_; - // The solver used by the inferencer. TypeSolver solver_; // relation function TypeRelationFn tuple_getitem_rel_; + TypeRelationFn make_tuple_rel_; // Unify two types Type Unify(const Type& t1, const Type& t2, const Span& span) { // TODO(tqchen, jroesch): propagate span to solver try { - Type inst_t1 = Instantiate(t1); - Type inst_t2 = Instantiate(t2); - return solver_.Unify(inst_t1, inst_t2); + return solver_.Unify(t1, t2); } catch (const dmlc::Error &e) { LOG(FATAL) << "Error unifying `" @@ -186,7 +182,7 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = Instantiate(this->VisitExpr(expr)); + Type ret = this->VisitExpr(expr); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; @@ -293,25 +289,6 @@ class TypeInferencer : private ExprFunctor { return rtype; } - // instantiates the type in the current type env - Type Instantiate(const Type& t) { - if (!t.defined()) { - return t; - } - auto* ft = t.as(); - if (ft == nullptr) { - return Bind(t, type_env_); - } - - // strip type params before binding, then restore them - auto strip_tvs = FuncTypeNode::make(ft->arg_types, ft->ret_type, - {}, ft->type_constraints); - auto* bound = Bind(strip_tvs, type_env_).as(); - CHECK(bound != nullptr); - return FuncTypeNode::make(bound->arg_types, bound->ret_type, - ft->type_params, bound->type_constraints); - } - // instantiate the function type with fresh FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args, const Span& span) { tvm::Map subst_map; @@ -370,12 +347,7 @@ class TypeInferencer : private ExprFunctor { // with an unknown return type if (inc_ty_node != nullptr) { Type ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - Array type_params; - for (auto type_arg : call->type_args) { - TypeVar tv = TypeVarNode::make("placeholder", TypeVarNode::Kind::kType); - type_params.push_back(tv); - } - Type func_type = FuncTypeNode::make(arg_types, ret_type, type_params, {}); + Type func_type = FuncTypeNode::make(arg_types, ret_type, {}, {}); Type unified = this->Unify(ftype, func_type, call->span); fn_ty_node = unified.as(); } @@ -433,13 +405,15 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const FunctionNode* f) final { solver_.Solve(); - - // instantiate all type vars first, then assemble rest of type - for (auto type_param : f->type_params) { - auto fresh = IncompleteTypeNode::make(TypeVarNode::Kind::kType); - type_env_.Set(type_param, fresh); + Array incomplete_arg_types; + for (auto param : f->params) { + incomplete_arg_types.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); } - + FuncType incompleteFuncType = + FuncTypeNode::make(incomplete_arg_types, + IncompleteTypeNode::make(TypeVarNode::Kind::kType), + {}, {}); + Array candidate_arg_types; for (auto param : f->params) { candidate_arg_types.push_back(GetType(param)); @@ -452,7 +426,7 @@ class TypeInferencer : private ExprFunctor { rtype, f->type_params, {}); - return candidateFuncType; + return Unify(incompleteFuncType, candidateFuncType, f->span); } }; From c2088f3c90a674ec20180a3d7921511279cba136 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 20 Dec 2018 17:22:27 -0800 Subject: [PATCH 34/55] Revert "Unifier should not keep additional state, ensure that type params *match* in func types" This reverts commit d6f0716d1d8e4a2da84e573d835d3c9cfe585abc. --- include/tvm/relay/type.h | 34 ++++++++++ src/relay/pass/type_solver.cc | 89 +++++++++++++++++--------- src/relay/pass/type_solver.h | 2 + tests/python/relay/test_type_solver.py | 33 ++-------- 4 files changed, 101 insertions(+), 57 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 69a8a4fb0bd7..ad62d1d78571 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -317,6 +317,40 @@ class TypeReporter : public NodeRef { using ContainerType = TypeReporterNode; }; +class TypeUnifier; +/*! + * \brief Data structure that handles type unification. + */ +class TypeUnifierNode : public Node { + public: + /*! + * \brief Returns a unified type based on the two arguments. + */ + TVM_DLL virtual Type Unify(const Type& dst, const Type& src) = 0; + + // unifier is not serializable. + void VisitAttrs(tvm::AttrVisitor* v) final {} + + static constexpr const char* _type_key = "relay.TypeUnifier"; + TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); +}; + +/*! + * \brief Container class of TypeUnifier. + * \sa TypeUnifierNode + */ +class TypeUnifier : public NodeRef { + public: + TypeUnifier() {} + explicit TypeUnifier(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { + } + TypeUnifierNode* operator->() const { + return static_cast(node_.get()); + } + using ContainerType = TypeUnifierNode; +}; + + /*! * \brief User defined type constraint function. * diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index db79c30f49de..ce056a678b28 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -61,15 +61,18 @@ class TypeSolver::OccursChecker : public TypeVisitor { bool found_; }; -class TypeSolver::Unifier : public TypeFunctor { +class TypeSolver::Unifier : + public TypeUnifierNode, public TypeFunctor { public: - explicit Unifier(TypeSolver* solver) : solver_(solver) {} + explicit Unifier(TypeSolver* solver) : solver_(solver), tv_map_({}) {} - Type Unify(const Type& src, const Type& dst) { + Type Unify(const Type& src, const Type& dst) final { // Known limitation // - handle shape pattern matching - TypeNode* lhs = solver_->GetTypeNode(dst); - TypeNode* rhs = solver_->GetTypeNode(src); + Type new_src = InstantiateTypeVar(src); + Type new_dst = InstantiateTypeVar(dst); + TypeNode* lhs = solver_->GetTypeNode(new_dst); + TypeNode* rhs = solver_->GetTypeNode(new_src); // do occur check so we don't create self-referencing structure if (lhs->FindRoot() == rhs->FindRoot()) { @@ -99,7 +102,7 @@ class TypeSolver::Unifier : public TypeFunctor { } } - // Checks whether lhs (taken to be a type hole) occurs in t, meaning + // Checks whether lhs (taken to be a type var) occurs in t, meaning // there is a recursive equality constraint, which should be rejected. // N.b.: A tautology like ?a = ?a is okay and should be checked for // *before* calling this method @@ -108,6 +111,45 @@ class TypeSolver::Unifier : public TypeFunctor { return rc.Check(t); } + // if t is a type var, replace with an incomplete type + Type InstantiateTypeVar(const Type& t) { + auto* tvn = t.as(); + if (tvn == nullptr) { + return t; + } + + TypeVar tv = GetRef(tvn); + auto it = tv_map_.find(tv); + if (it != tv_map_.end()) { + return (*it).second; + } + + IncompleteType hole = IncompleteTypeNode::make(tv->kind); + tv_map_.Set(tv, hole); + return hole; + } + + // instantiate away all type parameters in a function type + FuncType InstantiateFuncType(const FuncType& ft) { + for (auto type_param : ft->type_params) { + InstantiateTypeVar(type_param); + } + + // to avoid error when substituting type vars (TypeMutator + // errors out if the type var list in a FuncType contains an + // IncompleteType) + FuncType strip_tvs = FuncTypeNode::make(ft->arg_types, + ft->ret_type, {}, + ft->type_constraints); + Type transformed = Bind(strip_tvs, tv_map_); + + auto* new_ft = transformed.as(); + // drop the type param list altogether + return FuncTypeNode::make(new_ft->arg_types, + new_ft->ret_type, {}, + new_ft->type_constraints); + } + // default: unify only if alpha-equal Type VisitTypeDefault_(const Node* op, const Type& tn) override { NodeRef nr = GetRef(op); @@ -139,51 +181,37 @@ class TypeSolver::Unifier : public TypeFunctor { const auto* ftn = tn.as(); if (!ftn || op->arg_types.size() != ftn->arg_types.size() - || op->type_params.size() != ftn->type_params.size() || op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } - FuncType ft1 = GetRef(op); - FuncType ft2 = GetRef(ftn); - - // saves work if we can avoid remapping, etc. - if (AlphaEqual(ft1, ft2)) { - return ft1; - } - - // remap bound type params so we can compare by equality - Map subst_map; - for (size_t i = 0; i < ft1->type_params.size(); i++) { - subst_map.Set(ft2->type_params[i], ft1->type_params[i]); - } - - auto bound_ft2 = GetRef(Bind(ft2, subst_map).as()); + FuncType ft1 = InstantiateFuncType(GetRef(op)); + FuncType ft2 = InstantiateFuncType(GetRef(ftn)); - // now unify each field of the func types - Type ret_type = Unify(ft1->ret_type, bound_ft2->ret_type); + Type ret_type = Unify(ft1->ret_type, ft2->ret_type); std::vector arg_types; for (size_t i = 0; i < ft1->arg_types.size(); i++) { - Type arg_type = Unify(ft1->arg_types[i], bound_ft2->arg_types[i]); + Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); arg_types.push_back(arg_type); } std::vector type_constraints; for (size_t i = 0; i < ft1->type_constraints.size(); i++) { Type unified_constraint = Unify(ft1->type_constraints[i], - bound_ft2->type_constraints[i]); + ft2->type_constraints[i]); const auto* tcn = unified_constraint.as(); CHECK(tcn) << "Two type constraints unified into a non-constraint?" - << ft1->type_constraints[i] << " and " << bound_ft2->type_constraints[i]; + << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; type_constraints.push_back(GetRef(tcn)); } - return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); + return FuncTypeNode::make(arg_types, ret_type, {}, type_constraints); } private: TypeSolver* solver_; + Map tv_map_; }; class TypeSolver::Resolver : public TypeMutator { @@ -265,7 +293,7 @@ class TypeSolver::Propagator : public TypeFunctor { // constructor TypeSolver::TypeSolver() - : reporter_(make_node(this)) { + : reporter_(make_node(this)), unifier_(make_node(this)) { } // destructor @@ -281,8 +309,7 @@ TypeSolver::~TypeSolver() { // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - Unifier unifier(this); - return unifier.Unify(dst, src); + return unifier_->Unify(dst, src); } // Add type constraint to the solver. diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index e756440452d5..c6d5c20d737d 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -136,6 +136,8 @@ class TypeSolver { common::Arena arena_; /*! \brief Reporter that reports back to self */ TypeReporter reporter_; + /*! \brief Data structure for unifying types */ + TypeUnifier unifier_; /*! * \brief GetTypeNode that is corresponds to t. * if it do not exist, create a new one. diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 12546e6b25b1..3ebe4bdfe425 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -116,20 +116,17 @@ def test_unify_vars_under_tuples(): assert (unified == tup1 or unified == tup2) -def test_unify_functions_with_typevars(): +def test_instantiation_of_typevars(): solver = make_solver() - x = relay.ty.IncompleteType() - + t1 = relay.ty.IncompleteType() + t2 = relay.ty.IncompleteType() + a = relay.ty.TypeVar('a') b = relay.ty.TypeVar('b') - c = relay.ty.TypeVar('c') - d = relay.ty.TypeVar('d') - e = relay.ty.TypeVar('e') - f = relay.ty.TypeVar('f') - - ft1 = relay.ty.FuncType([a, b], relay.TupleType([b, c]), [a, b, c]) - ft2 = relay.ty.FuncType([d, e], relay.TupleType([e, x]), [d, e, f]) + + ft1 = relay.ty.FuncType([t1], t2) + ft2 = relay.ty.FuncType([a], b, [a, b]) unified = solver.Unify(ft1, ft2) assert (unified == solver.Resolve(ft1)) @@ -184,22 +181,6 @@ def test_backward_solving_after_child_update(): assert solver.Resolve(t4) == tup_concrete assert solver.Resolve(t5) == tup_concrete - -@raises(tvm._ffi.base.TVMError) -def test_unbound_type_var(): - solver = make_solver() - - # should not be able to unify because nothing is known about b and d - a = relay.ty.TypeVar('a') - b = relay.ty.TypeVar('b') - c = relay.ty.TypeVar('c') - d = relay.ty.TypeVar('d') - - ft1 = relay.ty.FuncType([a], b, [a]) - ft2 = relay.ty.FuncType([c], d, [c]) - solver.Unify(ft1, ft2) - - @raises(tvm._ffi.base.TVMError) def test_incompatible_tuple_unification(): solver = make_solver() From 8ac65709480288fffc089a9d0f515a07e6f09895 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 20 Dec 2018 17:41:45 -0800 Subject: [PATCH 35/55] Remove generalization and tests related to it for now; it needs to be redesigned --- src/relay/pass/type_infer.cc | 63 +-------------- tests/cpp/relay_pass_type_infer_test.cc | 16 ++-- tests/python/relay/test_type_infer.py | 102 ------------------------ 3 files changed, 7 insertions(+), 174 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 9861aef2f725..3501a970afa6 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -72,59 +72,6 @@ struct ResolvedTypeInfo { Array type_args = Array(NodePtr(nullptr)); }; -// Converts incomplete types remaining in function signature to type vars -class Generalizer : public TypeMutator { - public: - Generalizer() : subst_map_({}), varno_(0) {} - - // turns each distinct incomplete type into a type var and returns - // the transformed type with an array of all type vars present - Type Generalize(const Type &t) { - Type ret = VisitType(t); - - auto* ftn = ret.as(); - if (ftn == nullptr) { - return ret; - } - - // for a func type, we generalize at the top level - Array free_vars = FreeTypeVars(GetRef(ftn)); - return FuncTypeNode::make(ftn->arg_types, ftn->ret_type, free_vars, ftn->type_constraints); - } - - Type VisitType_(const IncompleteTypeNode *op) override { - IncompleteType t = GetRef(op); - auto it = subst_map_.find(t); - if (it != subst_map_.end()) { - return (*it).second; - } - - // generate a new type var, add to list - std::stringstream ss; - ss << "_var_" << varno_; - varno_++; - TypeVar new_var = TypeVarNode::make(ss.str(), TypeVarNode::Kind::kType); - subst_map_.Set(t, new_var); - return new_var; - } - - Type VisitType_(const FuncTypeNode *op) override { - // drop type params, only do it at the top level - Array arg_types; - for (auto arg_type : op->arg_types) { - arg_types.push_back(this->VisitType(arg_type)); - } - - Type ret_type = this->VisitType(op->ret_type); - - return FuncTypeNode::make(arg_types, ret_type, {}, op->type_constraints); - } - - private: - tvm::Map subst_map_; - int varno_; -}; - // // The inference algorithm can roughly be devided into three stages: // - Populate the constraints by visiting the expression (TypeInferencer.GetType) @@ -149,8 +96,6 @@ class TypeInferencer : private ExprFunctor { class Resolver; // internal environment Module mod_; - // Generalizer for handling let nodes - Generalizer gen_; // map from expression to checked type // type inferencer will populate it up std::unordered_map type_map_; @@ -243,10 +188,6 @@ class TypeInferencer : private ExprFunctor { } Type vtype = GetType(op->value); - // need to generalize inner functions immediately (per H-M) - if (isFunctionLiteral) { - vtype = gen_.Generalize(vtype); - } if (op->var->type_annotation.defined()) { vtype = Unify(vtype, op->var->type_annotation, op->span); } @@ -434,7 +375,7 @@ class TypeInferencer::Resolver : public ExprMutator { public: Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver), gen_(Generalizer()) { + : tmap_(tmap), solver_(solver) { } Expr VisitExpr_(const VarNode* op) final { @@ -483,7 +424,6 @@ class TypeInferencer::Resolver : public ExprMutator { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); Type checked_type = solver_->Resolve(it->second.checked_type); - checked_type = gen_.Generalize(checked_type); CHECK(checked_type.as() == nullptr) << "Cannot resolve type of " << GetRef(op) << " at " << op->span; @@ -569,7 +509,6 @@ class TypeInferencer::Resolver : public ExprMutator { private: const std::unordered_map& tmap_; TypeSolver* solver_; - Generalizer gen_; // whether attach the checked type as type_annotation // if original type anntation is missing. bool update_missing_type_annotation_{true}; diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index cc073f28f56f..50aed4c57338 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -6,20 +6,16 @@ TEST(Relay, SelfReference) { using namespace tvm; - auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType); - auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType); - auto x = relay::VarNode::make("x", type_a); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, type_b, - Array{type_a, type_b}); + auto tensor_type = relay::TensorTypeNode::make({}, ::tvm::Bool()); + auto x = relay::VarNode::make("x", relay::Type()); + auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); - auto y = relay::VarNode::make("y", type_a); + auto y = relay::VarNode::make("y", tensor_type); auto call = relay::CallNode::make(f, Array{ y }); - auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, type_b, - tvm::Array{type_a, type_b}); + auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map{})); - auto expected = relay::FuncTypeNode::make(tvm::Array{ type_a }, type_a, - tvm::Array{type_a} , {}); + auto expected = relay::FuncTypeNode::make(tvm::Array{ tensor_type }, tensor_type, {}, {}); CHECK(AlphaEqual(type_fx->checked_type(), expected)); } diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index ef7a5249baf2..63e1479dee37 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -146,18 +146,6 @@ def test_call_with_type_args(): assert ft.checked_type == relay.FuncType([], unit_type) -def test_generalized_call(): - x = relay.var('x') - f = relay.var('f') - func = relay.Function([x, f], relay.Call(f, [x])) - - a = relay.TypeVar('a') - b = relay.TypeVar('b') - - ft = relay.ir_pass.infer_type(func) - assert ft.checked_type == relay.FuncType([a, relay.FuncType([a], b)], b, [a, b]) - - def test_tuple(): tp = relay.TensorType((10,)) x = relay.var("x", tp) @@ -166,23 +154,6 @@ def test_tuple(): relay.TupleType([tp, tp])) -def test_generalized_tuple(): - x = relay.var('x') - y = relay.var('y') - z = relay.var('z') - - func = relay.Function([x, y, z], relay.Tuple([x, y, z])) - - a = relay.TypeVar('a') - b = relay.TypeVar('b') - c = relay.TypeVar('c') - ft = relay.ir_pass.infer_type(func) - assert ft.checked_type == relay.FuncType( - [a, b, c], - relay.TupleType([a, b, c]), - [a, b, c]) - - def test_free_expr(): x = relay.var("x", "float32") y = relay.add(x, x) @@ -208,79 +179,6 @@ def test_type_args(): assert sh2[1].value == 10 -def test_self_reference(): - """ - Program: - def f(x) { - return x; - } - """ - a = relay.TypeVar("a") - b = relay.TypeVar("b") - x = relay.var("x", a) - y = relay.var("y", a) - sb = relay.ScopeBuilder() - - f = relay.Function([x], x, b, [a, b]) - fx = relay.Function([y], relay.Call(f, [y])) - - x_type = relay.ir_pass.infer_type(x).checked_type - f_type = relay.ir_pass.infer_type(f).checked_type - call_type = relay.ir_pass.infer_type(fx).checked_type - assert f_type == relay.FuncType([a], a, [a]) - assert call_type == relay.FuncType([a], a, [a]) - - -def test_nested_recursive_function(): - """ - Program: - def f(x) { - let g = fun(x) { g(x) }; - g(x) - } - """ - x = relay.var("x") - y = relay.var("y") - g = relay.var("g") - f = relay.Function([x], - relay.Let(g, - relay.Function( - [y], relay.Call(g, [y])), - relay.Call(g, [x]))) - - a = relay.TypeVar("a") - b = relay.TypeVar("b") - f_type = relay.ir_pass.infer_type(f).checked_type - assert f_type == relay.FuncType([a], b, [a, b]) - - -def test_proper_inner_function_generalization(): - """ - Program: - def f() { - let id = fun(x) { x }; - let unit = id(()); - let idid = id(id); - unit - } - """ - x = relay.var("x") - unit = relay.var("unit") - id1 = relay.var("id") - id2 = relay.var("idid") - f = relay.Function( - [], - relay.Let(id1, relay.Function([x], x), - relay.Let( - unit, relay.Call(id1, [relay.Tuple([])]), - relay.Let( - id2, relay.Call(id1, [id1]), - unit)))) - - f_type = relay.ir_pass.infer_type(f).checked_type - assert f_type == relay.FuncType([], relay.TupleType([])) - - def test_global_var_recursion(): mod = relay.Module({}) gv = relay.GlobalVar("foo") From b4dec4054d0b95177748981bc6f1a87f515aaaab Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 21 Dec 2018 22:28:36 -0800 Subject: [PATCH 36/55] Rename method Bounded() to MarkBounded() --- src/relay/pass/util.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 03b72831e239..403863c1d757 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -156,7 +156,7 @@ class VarVisitor : protected ExprVisitor { return ret; } - void Bounded(const Var& v) { + void MarkBounded(const Var& v) { bound_vars_.Insert(v); vars_.Insert(v); } @@ -167,13 +167,13 @@ class VarVisitor : protected ExprVisitor { void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { - Bounded(param); + MarkBounded(param); } VisitExpr(op->body); } void VisitExpr_(const LetNode* op) final { - Bounded(op->var); + MarkBounded(op->var); VisitExpr(op->value); VisitExpr(op->body); } From b00eb4ff3119721db76101b6249190b45b481c06 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 21 Dec 2018 22:36:02 -0800 Subject: [PATCH 37/55] Require all type args to be specified or none (however, type arg test broken without generalization) --- src/relay/pass/type_infer.cc | 23 +++++++++++++---------- tests/python/relay/test_type_infer.py | 23 ----------------------- 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 3501a970afa6..49d0e11998b0 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -230,20 +230,14 @@ class TypeInferencer : private ExprFunctor { return rtype; } - // instantiate the function type with fresh - FuncType Instantiate(const FuncTypeNode* fn_ty, Array* ty_args, const Span& span) { + // substitute the type args in the function type + FuncType Instantiate(const FuncTypeNode* fn_ty, const Array& ty_args, const Span& span) { tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. for (size_t i = 0; i < fn_ty->type_params.size(); i++) { - auto ty_param = fn_ty->type_params[i]; - IncompleteType fresh = IncompleteTypeNode::make(ty_param->kind); - subst_map.Set(ty_param, fresh); - if (i < ty_args->size()) { - this->Unify(fresh, (*ty_args)[i], span); - } - ty_args->push_back(fresh); + subst_map.Set(fn_ty->type_params[i], ty_args[i]); } Type ret_type = fn_ty->ret_type; @@ -294,7 +288,16 @@ class TypeInferencer : private ExprFunctor { } Array type_args = call->type_args; - FuncType fn_ty = Instantiate(fn_ty_node, &type_args, call->span); + if (type_args.size() == 0) { + for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { + type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + } + } + CHECK(type_args.size() == fn_ty_node->type_params.size()) + << "Incorrect number of type args in " << call->span << ": " + << "Expected " << fn_ty_node->type_params.size() + << "but got " << type_args.size(); + FuncType fn_ty = Instantiate(fn_ty_node, type_args, call->span); AddTypeArgs(GetRef(call), type_args); diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 63e1479dee37..ac4eb1b404db 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -123,29 +123,6 @@ def test_incomplete_call(): assert ft.checked_type == relay.FuncType([tt, f_type], tt) -def test_call_with_type_args(): - a = relay.TypeVar('a') - b = relay.TypeVar('b') - - x = relay.Var('x', a) - f = relay.Var('f', relay.FuncType([a], b)) - func = relay.Function([x, f], relay.Call(f, [x]), b, [a, b]) - - unit_type = relay.TupleType([]) - v = relay.Var('v', unit_type) - concrete_func = relay.Function( - [], - relay.Call( - func, - [relay.Tuple([]), - relay.Function([v], relay.Tuple([]))], - type_args=[unit_type, unit_type]), - unit_type) - - ft = relay.ir_pass.infer_type(concrete_func) - assert ft.checked_type == relay.FuncType([], unit_type) - - def test_tuple(): tp = relay.TensorType((10,)) x = relay.var("x", tp) From 1eb81e40d4fe533181307e7effb9a8cd5ee8c844 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 21 Dec 2018 22:42:33 -0800 Subject: [PATCH 38/55] Whitespace --- src/relay/pass/type_infer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 49d0e11998b0..e0aa072fba5a 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -290,7 +290,7 @@ class TypeInferencer : private ExprFunctor { Array type_args = call->type_args; if (type_args.size() == 0) { for (size_t i = 0; i < fn_ty_node->type_params.size(); i++) { - type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + type_args.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); } } CHECK(type_args.size() == fn_ty_node->type_params.size()) From 11bda1c421772162c9a959e594e8179438d484d7 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 21 Dec 2018 22:57:49 -0800 Subject: [PATCH 39/55] Remove redundant func type creation in instantiation in TypeSolver --- src/relay/pass/type_solver.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index ce056a678b28..2e430cf6b1ab 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -145,9 +145,7 @@ class TypeSolver::Unifier : auto* new_ft = transformed.as(); // drop the type param list altogether - return FuncTypeNode::make(new_ft->arg_types, - new_ft->ret_type, {}, - new_ft->type_constraints); + return GetRef(new_ft); } // default: unify only if alpha-equal From 02824df54945478e985ad920da4557e9836c6a9e Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sat, 22 Dec 2018 11:54:05 -0800 Subject: [PATCH 40/55] Style change (variable rename) --- src/relay/pass/type_infer.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index e0aa072fba5a..320ef763b70d 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -182,8 +182,8 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const LetNode* op) final { // if the definition is a function literal, permit recursion - bool isFunctionLiteral = op->value.as() != nullptr; - if (isFunctionLiteral) { + bool is_functional_literal = op->value.as() != nullptr; + if (is_functional_literal) { type_map_[op->var].checked_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } @@ -191,7 +191,7 @@ class TypeInferencer : private ExprFunctor { if (op->var->type_annotation.defined()) { vtype = Unify(vtype, op->var->type_annotation, op->span); } - CHECK(isFunctionLiteral || !type_map_.count(op->var)); + CHECK(is_functional_literal || !type_map_.count(op->var)); // NOTE: no scoping is necessary because var are unique in program type_map_[op->var].checked_type = vtype; return GetType(op->body); From f7c14e8f5fab84a3234b6741fb7bac0bfa8965c6 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sat, 22 Dec 2018 13:36:45 -0800 Subject: [PATCH 41/55] Instantiate type vars in type_infer instead of unifier, do not expose unifier --- include/tvm/relay/type.h | 34 ------------- src/relay/pass/type_infer.cc | 35 +++++++++++++- src/relay/pass/type_solver.cc | 67 +++++++------------------- src/relay/pass/type_solver.h | 2 - tests/python/relay/test_type_solver.py | 6 ++- 5 files changed, 55 insertions(+), 89 deletions(-) diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index ad62d1d78571..69a8a4fb0bd7 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -317,40 +317,6 @@ class TypeReporter : public NodeRef { using ContainerType = TypeReporterNode; }; -class TypeUnifier; -/*! - * \brief Data structure that handles type unification. - */ -class TypeUnifierNode : public Node { - public: - /*! - * \brief Returns a unified type based on the two arguments. - */ - TVM_DLL virtual Type Unify(const Type& dst, const Type& src) = 0; - - // unifier is not serializable. - void VisitAttrs(tvm::AttrVisitor* v) final {} - - static constexpr const char* _type_key = "relay.TypeUnifier"; - TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node); -}; - -/*! - * \brief Container class of TypeUnifier. - * \sa TypeUnifierNode - */ -class TypeUnifier : public NodeRef { - public: - TypeUnifier() {} - explicit TypeUnifier(::tvm::NodePtr<::tvm::Node> n) : NodeRef(n) { - } - TypeUnifierNode* operator->() const { - return static_cast(node_.get()); - } - using ContainerType = TypeUnifierNode; -}; - - /*! * \brief User defined type constraint function. * diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 320ef763b70d..6bffcfcb2a1d 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -100,6 +100,10 @@ class TypeInferencer : private ExprFunctor { // type inferencer will populate it up std::unordered_map type_map_; + // used to ensure we don't have free type vars hanging around + // (a temporary measure until we have proper generalization implemented) + Map instantiation_map_; + // The solver used by the inferencer. TypeSolver solver_; // relation function @@ -120,6 +124,35 @@ class TypeInferencer : private ExprFunctor { return Type(); } } + + // this is a temporary measure to ensure all type vars are + // converted into fresh incomplete type vars until + // generalization is properly implemented + Type InstantiateAwayTypeVars(const Type &t) { + if (!t.defined()) { + return t; + } + auto* ft = t.as(); + if (ft == nullptr) { + return Bind(t, instantiation_map_); + } + + for (auto type_param : ft->type_params) { + if (instantiation_map_.find(type_param) != instantiation_map_.end()) { + continue; + } + instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + } + + Type ret_type = ft->ret_type; + if (!ret_type.defined()) { + ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); + } + + auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); + return Bind(strip_tvs, instantiation_map_); + } + // Lazily get type for expr // will call visit to deduce it if it is not in the type_map_ Type GetType(const Expr &expr) { @@ -127,7 +160,7 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = this->VisitExpr(expr); + Type ret = InstantiateAwayTypeVars(this->VisitExpr(expr)); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 2e430cf6b1ab..f6bc149cda49 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -61,18 +61,15 @@ class TypeSolver::OccursChecker : public TypeVisitor { bool found_; }; -class TypeSolver::Unifier : - public TypeUnifierNode, public TypeFunctor { +class TypeSolver::Unifier : public TypeFunctor { public: - explicit Unifier(TypeSolver* solver) : solver_(solver), tv_map_({}) {} + explicit Unifier(TypeSolver* solver) : solver_(solver) {} - Type Unify(const Type& src, const Type& dst) final { + Type Unify(const Type& src, const Type& dst) { // Known limitation // - handle shape pattern matching - Type new_src = InstantiateTypeVar(src); - Type new_dst = InstantiateTypeVar(dst); - TypeNode* lhs = solver_->GetTypeNode(new_dst); - TypeNode* rhs = solver_->GetTypeNode(new_src); + TypeNode* lhs = solver_->GetTypeNode(dst); + TypeNode* rhs = solver_->GetTypeNode(src); // do occur check so we don't create self-referencing structure if (lhs->FindRoot() == rhs->FindRoot()) { @@ -111,43 +108,6 @@ class TypeSolver::Unifier : return rc.Check(t); } - // if t is a type var, replace with an incomplete type - Type InstantiateTypeVar(const Type& t) { - auto* tvn = t.as(); - if (tvn == nullptr) { - return t; - } - - TypeVar tv = GetRef(tvn); - auto it = tv_map_.find(tv); - if (it != tv_map_.end()) { - return (*it).second; - } - - IncompleteType hole = IncompleteTypeNode::make(tv->kind); - tv_map_.Set(tv, hole); - return hole; - } - - // instantiate away all type parameters in a function type - FuncType InstantiateFuncType(const FuncType& ft) { - for (auto type_param : ft->type_params) { - InstantiateTypeVar(type_param); - } - - // to avoid error when substituting type vars (TypeMutator - // errors out if the type var list in a FuncType contains an - // IncompleteType) - FuncType strip_tvs = FuncTypeNode::make(ft->arg_types, - ft->ret_type, {}, - ft->type_constraints); - Type transformed = Bind(strip_tvs, tv_map_); - - auto* new_ft = transformed.as(); - // drop the type param list altogether - return GetRef(new_ft); - } - // default: unify only if alpha-equal Type VisitTypeDefault_(const Node* op, const Type& tn) override { NodeRef nr = GetRef(op); @@ -179,12 +139,19 @@ class TypeSolver::Unifier : const auto* ftn = tn.as(); if (!ftn || op->arg_types.size() != ftn->arg_types.size() + || op->type_params.size() != ftn->type_params.size() || op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } - FuncType ft1 = InstantiateFuncType(GetRef(op)); - FuncType ft2 = InstantiateFuncType(GetRef(ftn)); + // remap type vars so they match + Map subst_map; + for (size_t i = 0; i < op->type_params.size(); i++) { + subst_map.Set(ftn->type_params[i], op->type_params[i]); + } + + auto ft1 = GetRef(op); + auto ft2 = Downcast(Bind(GetRef(ftn), subst_map)); Type ret_type = Unify(ft1->ret_type, ft2->ret_type); @@ -209,7 +176,6 @@ class TypeSolver::Unifier : private: TypeSolver* solver_; - Map tv_map_; }; class TypeSolver::Resolver : public TypeMutator { @@ -291,7 +257,7 @@ class TypeSolver::Propagator : public TypeFunctor { // constructor TypeSolver::TypeSolver() - : reporter_(make_node(this)), unifier_(make_node(this)) { + : reporter_(make_node(this)) { } // destructor @@ -307,7 +273,8 @@ TypeSolver::~TypeSolver() { // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { - return unifier_->Unify(dst, src); + Unifier unifier(this); + return unifier.Unify(dst, src); } // Add type constraint to the solver. diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index c6d5c20d737d..e756440452d5 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -136,8 +136,6 @@ class TypeSolver { common::Arena arena_; /*! \brief Reporter that reports back to self */ TypeReporter reporter_; - /*! \brief Data structure for unifying types */ - TypeUnifier unifier_; /*! * \brief GetTypeNode that is corresponds to t. * if it do not exist, create a new one. diff --git a/tests/python/relay/test_type_solver.py b/tests/python/relay/test_type_solver.py index 3ebe4bdfe425..1e2fed0af1f8 100644 --- a/tests/python/relay/test_type_solver.py +++ b/tests/python/relay/test_type_solver.py @@ -116,7 +116,7 @@ def test_unify_vars_under_tuples(): assert (unified == tup1 or unified == tup2) -def test_instantiation_of_typevars(): +def test_binding_over_typevars(): solver = make_solver() t1 = relay.ty.IncompleteType() @@ -124,8 +124,10 @@ def test_instantiation_of_typevars(): a = relay.ty.TypeVar('a') b = relay.ty.TypeVar('b') + c = relay.ty.TypeVar('c') + d = relay.ty.TypeVar('d') - ft1 = relay.ty.FuncType([t1], t2) + ft1 = relay.ty.FuncType([t1], t2, [c, d]) ft2 = relay.ty.FuncType([a], b, [a, b]) unified = solver.Unify(ft1, ft2) assert (unified == solver.Resolve(ft1)) From 19aa3c267cd9d1a9573ea3e0188515c26b2d5544 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sat, 22 Dec 2018 13:40:32 -0800 Subject: [PATCH 42/55] Whitespace fixes and redundant check --- src/relay/pass/type_infer.cc | 9 +++------ src/relay/pass/type_solver.cc | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 6bffcfcb2a1d..005b5bd63a61 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -138,21 +138,18 @@ class TypeInferencer : private ExprFunctor { } for (auto type_param : ft->type_params) { - if (instantiation_map_.find(type_param) != instantiation_map_.end()) { - continue; - } instantiation_map_.Set(type_param, IncompleteTypeNode::make(TypeVarNode::Kind::kType)); } - + Type ret_type = ft->ret_type; if (!ret_type.defined()) { ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); } - + auto strip_tvs = FuncTypeNode::make(ft->arg_types, ret_type, {}, ft->type_constraints); return Bind(strip_tvs, instantiation_map_); } - + // Lazily get type for expr // will call visit to deduce it if it is not in the type_map_ Type GetType(const Expr &expr) { diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index f6bc149cda49..1469286213f7 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -139,7 +139,7 @@ class TypeSolver::Unifier : public TypeFunctor { const auto* ftn = tn.as(); if (!ftn || op->arg_types.size() != ftn->arg_types.size() - || op->type_params.size() != ftn->type_params.size() + || op->type_params.size() != ftn->type_params.size() || op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } From b697a590548a0ac7336ec9015bdd38e16a744698 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sat, 22 Dec 2018 13:55:15 -0800 Subject: [PATCH 43/55] Simplify function literal type inference case --- src/relay/pass/type_infer.cc | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 005b5bd63a61..85f4aba08cdf 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -379,28 +379,16 @@ class TypeInferencer : private ExprFunctor { Type VisitExpr_(const FunctionNode* f) final { solver_.Solve(); - Array incomplete_arg_types; - for (auto param : f->params) { - incomplete_arg_types.push_back(IncompleteTypeNode::make(TypeVarNode::Kind::kType)); - } - FuncType incompleteFuncType = - FuncTypeNode::make(incomplete_arg_types, - IncompleteTypeNode::make(TypeVarNode::Kind::kType), - {}, {}); - - Array candidate_arg_types; + Array arg_types; for (auto param : f->params) { - candidate_arg_types.push_back(GetType(param)); + arg_types.push_back(GetType(param)); } Type rtype = GetType(f->body); if (f->ret_type.defined()) { rtype = this->Unify(f->ret_type, rtype, f->span); } - FuncType candidateFuncType = FuncTypeNode::make(candidate_arg_types, - rtype, - f->type_params, {}); - - return Unify(incompleteFuncType, candidateFuncType, f->span); + auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); + return solver_.Resolve(ret); } }; From fa4f5488713dbf73427922ddd38d4906412963ba Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 24 Dec 2018 14:45:37 -0500 Subject: [PATCH 44/55] Style nitpick --- src/relay/pass/type_solver.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 1469286213f7..265eac539e4f 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -103,7 +103,7 @@ class TypeSolver::Unifier : public TypeFunctor { // there is a recursive equality constraint, which should be rejected. // N.b.: A tautology like ?a = ?a is okay and should be checked for // *before* calling this method - bool CheckOccurs(TypeNode *lhs, const Type &t) { + bool CheckOccurs(TypeNode* lhs, const Type& t) { OccursChecker rc(solver_, lhs); return rc.Check(t); } From af32ff435b29dbdf7a30bc5a3d66b5a14b967e9a Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 24 Dec 2018 12:55:25 -0800 Subject: [PATCH 45/55] Don't drop type params in unifier --- src/relay/pass/type_solver.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 265eac539e4f..757f80a3bb1f 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -171,7 +171,7 @@ class TypeSolver::Unifier : public TypeFunctor { type_constraints.push_back(GetRef(tcn)); } - return FuncTypeNode::make(arg_types, ret_type, {}, type_constraints); + return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); } private: From 755611a70b18cdd2bc8cd87ce3f2245ea92307a8 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 24 Dec 2018 13:05:33 -0800 Subject: [PATCH 46/55] Clean up and better document type var instantiation hack --- src/relay/pass/type_infer.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 85f4aba08cdf..af4cc6607a44 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -125,10 +125,10 @@ class TypeInferencer : private ExprFunctor { } } - // this is a temporary measure to ensure all type vars are - // converted into fresh incomplete type vars until - // generalization is properly implemented - Type InstantiateAwayTypeVars(const Type &t) { + // Substitutes every type var in t with a corresponding incomplete type. + // This is a temporary measure to ensure type vars behave until + // generalization is properly implemented. + Type Instantiate(const Type &t) { if (!t.defined()) { return t; } @@ -157,7 +157,7 @@ class TypeInferencer : private ExprFunctor { if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; } - Type ret = InstantiateAwayTypeVars(this->VisitExpr(expr)); + Type ret = Instantiate(this->VisitExpr(expr)); ResolvedTypeInfo& rti = type_map_[expr]; rti.checked_type = ret; return ret; @@ -261,7 +261,7 @@ class TypeInferencer : private ExprFunctor { } // substitute the type args in the function type - FuncType Instantiate(const FuncTypeNode* fn_ty, const Array& ty_args, const Span& span) { + FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array& ty_args) { tvm::Map subst_map; // Build a subsitituion map up from the function type and type arguments. @@ -327,7 +327,7 @@ class TypeInferencer : private ExprFunctor { << "Incorrect number of type args in " << call->span << ": " << "Expected " << fn_ty_node->type_params.size() << "but got " << type_args.size(); - FuncType fn_ty = Instantiate(fn_ty_node, type_args, call->span); + FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); AddTypeArgs(GetRef(call), type_args); From a4f5c09ab659b63b5265ccd430d62109cea26d01 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 27 Dec 2018 20:12:55 -0800 Subject: [PATCH 47/55] MergeFromTo gathers rel links recursively --- src/relay/pass/type_solver.cc | 80 +++++++++++++++++++++++++++++++++-- src/relay/pass/type_solver.h | 18 +------- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 757f80a3bb1f..897f49ba8909 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -92,9 +92,6 @@ class TypeSolver::Unifier : public TypeFunctor { CHECK(resolved.defined()) << "Unable to unify parent types: " << lhs->resolved_type << " and " << rhs->resolved_type; - TypeNode* top = solver_->GetTypeNode(resolved); - solver_->MergeFromTo(lhs, top); - solver_->MergeFromTo(rhs, top); return resolved; } } @@ -255,6 +252,77 @@ class TypeSolver::Propagator : public TypeFunctor { RelationNode* rel_; }; +// similarly, we use TypeFunctor so we can use +// the default visitor case to avoid more overrides +class TypeSolver::Merger : public TypeFunctor { + public: + explicit Merger(TypeSolver* solver) : solver_(solver) {} + + // Merges src node to dst, ensures *all* type relations of all + // child nodes of src are transferred to dst. + void Merge(TypeNode* src, TypeNode* dst) { + if (src == dst) return; + dst_ = dst; + VisitType(src->resolved_type); + // set parent at the end so later calls to GetTypeNode go back to src + src->parent = dst; + } + + // Transfers any relations linked to t to the stored dst. + // Any unresolved relations are added back to the queue, since + // there is now new information + void TransferLinks(const Type& t) { + TypeNode* src = solver_->GetTypeNode(t); + // move the link to the to dst + for (auto* rlink = src->rel_list.head; rlink != nullptr;) { + // store next pointer first before rlink get moved + auto* next = rlink->next; + // if the relation is not yet resolved + // send the relation to the queue + if (!rlink->value->resolved) { + solver_->AddToQueue(rlink->value); + dst_->rel_list.Push(rlink); + } + rlink = next; + } + } + + void VisitTypeDefault_(const Node* op) override { + NodeRef nr = GetRef(op); + Type t = GetRef(nr.as_derived()); + TransferLinks(t); + } + + void VisitType_(const TupleTypeNode* ttn) override { + auto tup = GetRef(ttn); + TransferLinks(tup); + + for (auto field : tup->fields) { + VisitType(field); + } + } + + void VisitType_(const FuncTypeNode* ftn) override { + auto func = GetRef(ftn); + TransferLinks(func); + + VisitType(func->ret_type); + for (auto arg : func->arg_types) { + VisitType(arg); + } + for (auto param : func->type_params) { + VisitType(param); + } + for (auto constraint : func->type_constraints) { + VisitType(constraint); + } + } + + private: + TypeSolver* solver_; + TypeNode* dst_; +}; + // constructor TypeSolver::TypeSolver() : reporter_(make_node(this)) { @@ -271,6 +339,12 @@ TypeSolver::~TypeSolver() { } } +// merge src type node to dst +void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { + Merger merger(this); + merger.Merge(src, dst); +} + // Add equality constraint Type TypeSolver::Unify(const Type& dst, const Type& src) { Unifier unifier(this); diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index e756440452d5..3e5cede29391 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -70,6 +70,7 @@ class TypeSolver { class Unifier; class Resolver; class Propagator; + class Merger; class Reporter; struct TypeNode; struct RelationNode; @@ -168,22 +169,7 @@ class TypeSolver { * \param src The source operand * \param dst The dst operand. */ - void MergeFromTo(TypeNode* src, TypeNode* dst) { - if (src == dst) return; - src->parent = dst; - // move the link to the to dst - for (auto* rlink = src->rel_list.head; rlink != nullptr;) { - // store next pointer first before rlink get moved - auto* next = rlink->next; - // if the relation is not yet resolved - // send the relation to the queue - if (!rlink->value->resolved) { - this->AddToQueue(rlink->value); - dst->rel_list.Push(rlink); - } - rlink = next; - } - } + void MergeFromTo(TypeNode* src, TypeNode* dst); }; } // namespace relay From 1ee69cc38b253d7d5986120ef58cbe06f1fd3c32 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 27 Dec 2018 21:22:31 -0800 Subject: [PATCH 48/55] Copy links over to avoid circular links --- src/relay/pass/type_solver.cc | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 897f49ba8909..b7fff591250a 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -92,6 +92,9 @@ class TypeSolver::Unifier : public TypeFunctor { CHECK(resolved.defined()) << "Unable to unify parent types: " << lhs->resolved_type << " and " << rhs->resolved_type; + TypeNode* top = solver_->GetTypeNode(resolved); + solver_->MergeFromTo(lhs, top); + solver_->MergeFromTo(rhs, top); return resolved; } } @@ -268,34 +271,35 @@ class TypeSolver::Merger : public TypeFunctor { src->parent = dst; } - // Transfers any relations linked to t to the stored dst. + // Copies any relations linked to t to the stored dst. // Any unresolved relations are added back to the queue, since // there is now new information - void TransferLinks(const Type& t) { + void CopyLinks(const Type& t) { TypeNode* src = solver_->GetTypeNode(t); + if (src == dst_) return; // move the link to the to dst - for (auto* rlink = src->rel_list.head; rlink != nullptr;) { - // store next pointer first before rlink get moved - auto* next = rlink->next; + for (auto* rlink = src->rel_list.head; rlink != nullptr; rlink = rlink->next) { // if the relation is not yet resolved // send the relation to the queue if (!rlink->value->resolved) { solver_->AddToQueue(rlink->value); - dst_->rel_list.Push(rlink); + // copy link to avoid introducing circular references + auto* new_rlink = solver_->arena_.make >(); + new_rlink->value = rlink->value; + dst_->rel_list.Push(new_rlink); } - rlink = next; } } void VisitTypeDefault_(const Node* op) override { NodeRef nr = GetRef(op); Type t = GetRef(nr.as_derived()); - TransferLinks(t); + CopyLinks(t); } void VisitType_(const TupleTypeNode* ttn) override { auto tup = GetRef(ttn); - TransferLinks(tup); + CopyLinks(tup); for (auto field : tup->fields) { VisitType(field); @@ -304,7 +308,7 @@ class TypeSolver::Merger : public TypeFunctor { void VisitType_(const FuncTypeNode* ftn) override { auto func = GetRef(ftn); - TransferLinks(func); + CopyLinks(func); VisitType(func->ret_type); for (auto arg : func->arg_types) { From 8cf9962e4778b947b0d46701846d85d74a5cb898 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 27 Dec 2018 21:57:01 -0800 Subject: [PATCH 49/55] Use set for storing rels, propagate after merging typenodes --- src/relay/pass/type_solver.cc | 36 +++++++++++++++++------------------ src/relay/pass/type_solver.h | 8 ++++---- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index b7fff591250a..feef339af982 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -212,9 +212,7 @@ class TypeSolver::Propagator : public TypeFunctor { void AddRelToList(const Type& t) { TypeNode* tnode = solver_->GetTypeNode(t); - LinkNode* rlink = solver_->arena_.make >(); - rlink->value = rel_; - tnode->rel_list.Push(rlink); + tnode->rel_set.insert(rel_); } void VisitTypeDefault_(const Node* op) override { @@ -269,24 +267,26 @@ class TypeSolver::Merger : public TypeFunctor { VisitType(src->resolved_type); // set parent at the end so later calls to GetTypeNode go back to src src->parent = dst; + + // now propagate any relations to child nodes, since change to + // a child node should update parent too + for (auto* rel : dst->rel_set) { + Propagator prop(solver_, rel); + prop.Propagate(dst->resolved_type); + } } - // Copies any relations linked to t to the stored dst. + // Transfers any relations linked to t to the stored dst. // Any unresolved relations are added back to the queue, since // there is now new information - void CopyLinks(const Type& t) { + void TransferLinks(const Type& t) { TypeNode* src = solver_->GetTypeNode(t); if (src == dst_) return; - // move the link to the to dst - for (auto* rlink = src->rel_list.head; rlink != nullptr; rlink = rlink->next) { - // if the relation is not yet resolved - // send the relation to the queue - if (!rlink->value->resolved) { - solver_->AddToQueue(rlink->value); - // copy link to avoid introducing circular references - auto* new_rlink = solver_->arena_.make >(); - new_rlink->value = rlink->value; - dst_->rel_list.Push(new_rlink); + for (auto* rel : src->rel_set) { + // if the relation is not yet resolved, add to queue + if (!rel->resolved) { + solver_->AddToQueue(rel); + dst_->rel_set.insert(rel); } } } @@ -294,12 +294,12 @@ class TypeSolver::Merger : public TypeFunctor { void VisitTypeDefault_(const Node* op) override { NodeRef nr = GetRef(op); Type t = GetRef(nr.as_derived()); - CopyLinks(t); + TransferLinks(t); } void VisitType_(const TupleTypeNode* ttn) override { auto tup = GetRef(ttn); - CopyLinks(tup); + TransferLinks(tup); for (auto field : tup->fields) { VisitType(field); @@ -308,7 +308,7 @@ class TypeSolver::Merger : public TypeFunctor { void VisitType_(const FuncTypeNode* ftn) override { auto func = GetRef(ftn); - CopyLinks(func); + TransferLinks(func); VisitType(func->ret_type); for (auto arg : func->arg_types) { diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 3e5cede29391..b4635fdec331 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -83,15 +83,15 @@ class TypeSolver { * that can unifies the same types to the name resolved_type. * * It also contains collection of links to related Relations, - * which is stored in rel_list. + * which is stored in rel_set. */ struct TypeNode { /*! \brief The final resolved type */ Type resolved_type; /*! \brief type node in the union find algorithm */ TypeNode* parent{nullptr}; - /*! \brief list of relations that is related to this type node */ - LinkedList rel_list; + /*! \brief set of relations that is related to this type node */ + std::unordered_set rel_set; /*! * \brief Find the root type node, perform path compression * \return The root type node. @@ -131,7 +131,7 @@ class TypeSolver { size_t num_resolved_rels_{0}; /*! \brief map from type node to types. */ std::unordered_map tmap_; - /*! \breif Internal queue to update the relation */ + /*! \brief Internal queue to update the relation */ std::queue update_queue_; /*! \brief allocator of all the internal node obhect*/ common::Arena arena_; From 0f9d5809972dcd3a5bad30529e09bbab6ea38375 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 27 Dec 2018 22:25:13 -0800 Subject: [PATCH 50/55] Propagator should be able to propagate multiple relations at once --- src/relay/pass/type_solver.cc | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index feef339af982..caea3755b8f9 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -203,27 +203,30 @@ class TypeSolver::Resolver : public TypeMutator { // most of the overrides. class TypeSolver::Propagator : public TypeFunctor { public: - explicit Propagator(TypeSolver* solver, RelationNode* rel) : solver_(solver), rel_(rel) {} + explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) + : solver_(solver), rels_(rels) {} // adds the relation node to t and all child types of t void Propagate(const Type& t) { VisitType(t); } - void AddRelToList(const Type& t) { + void UpdateRelSet(const Type& t) { TypeNode* tnode = solver_->GetTypeNode(t); - tnode->rel_set.insert(rel_); + for (auto* rel : *rels_) { + tnode->rel_set.insert(rel); + } } void VisitTypeDefault_(const Node* op) override { NodeRef nr = GetRef(op); Type t = GetRef(nr.as_derived()); - AddRelToList(t); + UpdateRelSet(t); } void VisitType_(const TupleTypeNode* op) override { TupleType tt = GetRef(op); - AddRelToList(tt); + UpdateRelSet(tt); for (const Type& t : tt->fields) { Propagate(t); @@ -232,7 +235,7 @@ class TypeSolver::Propagator : public TypeFunctor { void VisitType_(const FuncTypeNode* op) override { FuncType ft = GetRef(op); - AddRelToList(ft); + UpdateRelSet(ft); Propagate(ft->ret_type); for (auto arg_type : ft->arg_types) { @@ -250,7 +253,7 @@ class TypeSolver::Propagator : public TypeFunctor { private: TypeSolver* solver_; - RelationNode* rel_; + const std::unordered_set* rels_; }; // similarly, we use TypeFunctor so we can use @@ -268,12 +271,10 @@ class TypeSolver::Merger : public TypeFunctor { // set parent at the end so later calls to GetTypeNode go back to src src->parent = dst; - // now propagate any relations to child nodes, since change to + // now propagate relations to child nodes, since change to // a child node should update parent too - for (auto* rel : dst->rel_set) { - Propagator prop(solver_, rel); - prop.Propagate(dst->resolved_type); - } + Propagator prop(solver_, &dst->rel_set); + prop.Propagate(dst->resolved_type); } // Transfers any relations linked to t to the stored dst. @@ -370,7 +371,8 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint) { tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - Propagator prop(this, rnode); + std::unordered_set singleton { rnode }; + Propagator prop(this, &singleton); prop.Propagate(tnode->resolved_type); } // add the relation to the working queue. From eb2754df59c724bf2346e9bb2071af9519dc3e66 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 14 Jan 2019 14:03:04 -0800 Subject: [PATCH 51/55] Correct description of AllVars() utility --- include/tvm/relay/pass.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 0d3697a5b507..566d69cc6b0b 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -134,7 +134,7 @@ tvm::Array FreeVars(const Expr& expr); * * \param expr the expression. * - * \return List of free vars, in the PostDFS order in the expression. + * \return List of all vars, in the PostDFS order in the expression. */ tvm::Array AllVars(const Expr& expr); From 02e1f1cbb00ae6175f39fd874bda4a9629e08e8c Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 15 Jan 2019 13:53:56 -0800 Subject: [PATCH 52/55] Ensure null annotations replaced in AD --- src/relay/pass/gradient.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 601a09b35d1a..3e8e952356e7 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -53,6 +53,7 @@ Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking auto ty = t.as(); CHECK(ty) << "input should be a function"; + CHECK(ty->ret_type.defined()); return FuncTypeNode::make(ty->arg_types, TupleTypeNode::make({ ty->ret_type, @@ -172,6 +173,14 @@ struct ReverseAD : ExprFunctor { } }; +/*! \brief Checks whether the annotation is defined. + * If the annotation is defined, return it. + * Otherwise, return a type hole. + */ +Type FromAnnotation(const Type& annotation) { + return (annotation.defined()) ? annotation : IncompleteTypeNode::make(TypeVarNode::Kind::kType); +} + Expr FirstOrderGradient(const Expr& re, const Module& mod) { // Currently we first remove any global functions for the first // order case. @@ -207,11 +216,13 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { }); std::vector vt; for (const auto& p : f->params) { - vt.push_back(p->type_annotation); + vt.push_back(FromAnnotation(p->type_annotation)); } + Type ret_type = FromAnnotation(f->ret_type); + return FunctionNode::make(f->params, body, - TupleTypeNode::make({f->ret_type, TupleTypeNode::make({})}), + TupleTypeNode::make({ret_type, TupleTypeNode::make(vt)}), {}); } From 96e4ebb69c07517bb9817239bcdd78728f2b4232 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 15 Jan 2019 13:56:22 -0800 Subject: [PATCH 53/55] Unnecessary check --- src/relay/pass/gradient.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 3e8e952356e7..1bc2c88cb501 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -53,7 +53,6 @@ Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking auto ty = t.as(); CHECK(ty) << "input should be a function"; - CHECK(ty->ret_type.defined()); return FuncTypeNode::make(ty->arg_types, TupleTypeNode::make({ ty->ret_type, From 7c1a42f63318ec0cdbfd5dee43e2ed21df236624 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 13:50:32 -0800 Subject: [PATCH 54/55] Leave grad ret type to be inferred iannotations gone --- src/relay/pass/gradient.cc | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 1bc2c88cb501..a0d551ab2daf 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -172,14 +172,6 @@ struct ReverseAD : ExprFunctor { } }; -/*! \brief Checks whether the annotation is defined. - * If the annotation is defined, return it. - * Otherwise, return a type hole. - */ -Type FromAnnotation(const Type& annotation) { - return (annotation.defined()) ? annotation : IncompleteTypeNode::make(TypeVarNode::Kind::kType); -} - Expr FirstOrderGradient(const Expr& re, const Module& mod) { // Currently we first remove any global functions for the first // order case. @@ -213,16 +205,25 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { }); return Pair(res.foward, grad); }); + + // if type annotations are provided, we will construct a ret type; + // otherwise, leave it to be inferred + Type ret_type = Type(); std::vector vt; - for (const auto& p : f->params) { - vt.push_back(FromAnnotation(p->type_annotation)); + bool missing = !f->ret_type.defined(); + for (const auto& p: f->params) { + if (missing || !p->type_annotation.defined()) { + missing = true; + break; + } + vt.push_back(p->type_annotation); + } + + if (!missing) { + ret_type = TupleTypeNode::make({f->ret_type, TupleTypeNode::make(vt)}); } - Type ret_type = FromAnnotation(f->ret_type); - return FunctionNode::make(f->params, - body, - TupleTypeNode::make({ret_type, TupleTypeNode::make(vt)}), - {}); + return FunctionNode::make(f->params, body, ret_type, {}); } TVM_REGISTER_API("relay._ir_pass.first_order_gradient") From c02375cf3d7d0387ec61099974af5999f9431c68 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 16 Jan 2019 13:56:21 -0800 Subject: [PATCH 55/55] lint --- src/relay/pass/gradient.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index a0d551ab2daf..251d7153e4e6 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -211,7 +211,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { Type ret_type = Type(); std::vector vt; bool missing = !f->ret_type.defined(); - for (const auto& p: f->params) { + for (const auto& p : f->params) { if (missing || !p->type_annotation.defined()) { missing = true; break;