-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relay][Passes] Iterative A-normal Traversals #7374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -341,26 +341,34 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, | |
| Type VisitExpr_(const OpNode* op) final { return op->op_type; } | ||
|
|
||
| Type VisitExpr_(const LetNode* let) final { | ||
| // if the definition is a function literal, permit recursion | ||
| bool is_functional_literal = let->value.as<FunctionNode>() != nullptr; | ||
| Type let_type = IncompleteType(Kind::kType); | ||
|
|
||
| if (is_functional_literal) { | ||
| let_type = GetType(let->var); | ||
| type_map_[let->var].checked_type = let_type; | ||
| } | ||
| auto pre_visit = [this](const LetNode* op) { | ||
| // if the definition is a function literal, permit recursion | ||
| bool is_functional_literal = op->value.as<FunctionNode>() != nullptr; | ||
| Type let_type = IncompleteType(Kind::kType); | ||
|
|
||
| if (is_functional_literal) { | ||
| let_type = this->GetType(op->var); | ||
| this->type_map_[op->var].checked_type = let_type; | ||
| } | ||
|
|
||
| if (let->var->type_annotation.defined()) { | ||
| let_type = Unify(let_type, let->var->type_annotation, let->span); | ||
| } | ||
| if (op->var->type_annotation.defined()) { | ||
| let_type = this->Unify(let_type, op->var->type_annotation, op->span); | ||
| } | ||
|
|
||
| Type vtype = GetType(let->value); | ||
| let_type = Unify(let_type, vtype, let->span); | ||
| Type vtype = this->GetType(op->value); | ||
| let_type = this->Unify(let_type, vtype, op->span); | ||
|
|
||
| ICHECK(is_functional_literal || !type_map_.count(let->var)); | ||
| // NOTE: no scoping is necessary because var are unique in program | ||
| type_map_[let->var].checked_type = let_type; | ||
| return GetType(let->body); | ||
| ICHECK(is_functional_literal || !this->type_map_.count(op->var)); | ||
| // NOTE: no scoping is necessary because var are unique in program | ||
| this->type_map_[op->var].checked_type = let_type; | ||
| }; | ||
| auto post_visit = [this](const LetNode* op) { | ||
| Expr expr = GetRef<Expr>(op); | ||
| this->memo_[expr] = this->GetType(op->body); | ||
| this->type_map_[expr].checked_type = this->memo_[expr]; | ||
| }; | ||
| ExpandANormalForm(let, pre_visit, post_visit); | ||
| return memo_[GetRef<Expr>(let)]; | ||
| } | ||
|
|
||
| Type VisitExpr_(const IfNode* ite) final { | ||
|
|
@@ -603,7 +611,21 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator { | |
|
|
||
| Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); } | ||
|
|
||
| Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); } | ||
| Expr VisitExpr_(const LetNode* op) final { | ||
| auto pre_visit = [this](const LetNode* op) { | ||
| this->VisitExpr(op->var); | ||
| this->VisitExpr(op->value); | ||
| }; | ||
| auto post_visit = [this](const LetNode* op) { | ||
| Expr expr = GetRef<Expr>(op); | ||
| Var var = Downcast<Var>(this->VisitExpr(op->var)); | ||
| Expr value = this->VisitExpr(op->value); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need to visit var and value again in the post?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We just need to pull the values out of the cache. Instead of maintaining a cache shared by the two lambdas, I'm using the memorization cache in the Mutator. The second time visit is called, it will short circuit and return the previously computed value.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i see. thanks. |
||
| Expr body = this->VisitExpr(op->body); | ||
| this->memo_[expr] = this->AttachCheckedType(op, Let(var, value, body)); | ||
| }; | ||
| ExpandANormalForm(op, pre_visit, post_visit); | ||
| return memo_[GetRef<Expr>(op)]; | ||
| } | ||
|
|
||
| Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); } | ||
|
|
||
|
|
@@ -738,6 +760,7 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) { | |
| } | ||
|
|
||
| struct AllCheckTypePopulated : MixedModeVisitor { | ||
| using MixedModeVisitor::VisitExpr_; | ||
| void DispatchExprVisit(const Expr& e) { | ||
| if (e.as<OpNode>()) { | ||
| return; | ||
|
|
@@ -751,6 +774,17 @@ struct AllCheckTypePopulated : MixedModeVisitor { | |
| ICHECK(e->checked_type_.defined()) << "Expression: " << e; | ||
| return ExprVisitor::VisitExpr(e); | ||
| } | ||
| void VisitExpr_(const LetNode* op) final { | ||
| auto pre_visit = [this](const LetNode* op) { | ||
| this->VisitExpr(op->var); | ||
| this->VisitExpr(op->value); | ||
| }; | ||
| auto post_visit = [this](const LetNode* op) { | ||
| this->VisitExpr(op->body); | ||
| this->visit_counter_[op] += 1; | ||
| }; | ||
| ExpandANormalForm(op, pre_visit, post_visit); | ||
| } | ||
| }; | ||
|
|
||
| void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.