-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Small refactors and bug fixes. #2281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
c16aac3
Move IsPrimitive to method on FunctionNode
jroesch 8602c6b
Fix typo
jroesch 2cb2b43
Move functor to its own file to depend on op
jroesch 7eda7a3
Export from top-level
jroesch 7db8523
Fix up after expr_functor.py
jroesch 91777df
Use new API
jroesch cf91ad0
Add case to graph codegen
jroesch a0435cd
Fix linting
jroesch 653f19d
Fix again
jroesch File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name | ||
| """The expression functor of Relay.""" | ||
|
|
||
| from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant | ||
| from .op import Op | ||
|
|
||
| class ExprFunctor: | ||
| """ | ||
| An abstract visitor defined over Expr. | ||
|
|
||
| Defines the default dispatch over expressions, and | ||
| implements memoization. | ||
| """ | ||
| def __init__(self): | ||
| self.memo_map = {} | ||
|
|
||
| # pylint: disable=no-else-return | ||
| def visit(self, expr): | ||
| """Apply the visitor to an expression.""" | ||
| found = self.memo_map.get(expr) | ||
| if found: | ||
| return found | ||
|
|
||
| if isinstance(expr, Function): | ||
| res = self.visit_function(expr) | ||
| elif isinstance(expr, Call): | ||
| res = self.visit_call(expr) | ||
| elif isinstance(expr, Let): | ||
| res = self.visit_let(expr) | ||
| elif isinstance(expr, Var): | ||
| res = self.visit_var(expr) | ||
| elif isinstance(expr, GlobalVar): | ||
| res = self.visit_global_var(expr) | ||
| elif isinstance(expr, If): | ||
| res = self.visit_if(expr) | ||
| elif isinstance(expr, Tuple): | ||
| res = self.visit_tuple(expr) | ||
| elif isinstance(expr, TupleGetItem): | ||
| res = self.visit_tuple_getitem(expr) | ||
| elif isinstance(expr, Constant): | ||
| res = self.visit_constant(expr) | ||
| elif isinstance(expr, Op): | ||
| res = self.visit_op(expr) | ||
| else: | ||
| raise Exception("warning unhandled case: {0}".format(type(expr))) | ||
|
|
||
| self.memo_map[expr] = res | ||
|
|
||
| return res | ||
|
|
||
| def visit_function(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_let(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_call(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_var(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_type(self, typ): | ||
| return typ | ||
|
|
||
| def visit_if(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_tuple(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_tuple_getitem(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_global_var(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_op(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
| def visit_constant(self, _): | ||
| raise NotImplementedError() | ||
|
|
||
|
|
||
| class ExprMutator(ExprFunctor): | ||
| """ | ||
| A functional visitor over Expr. | ||
|
|
||
| The default behavior recursively traverses the AST | ||
| and reconstructs the AST. | ||
| """ | ||
| def visit_function(self, fn): | ||
| new_body = self.visit(fn.body) | ||
| return Function( | ||
| list(fn.params), | ||
| new_body, | ||
| fn.ret_type, | ||
| fn.type_params, | ||
| fn.attrs) | ||
|
|
||
| def visit_let(self, let): | ||
| new_var = self.visit(let.var) | ||
| new_val = self.visit(let.value) | ||
| new_body = self.visit(let.body) | ||
| return Let(new_var, new_val, new_body) | ||
|
|
||
| def visit_call(self, call): | ||
| new_fn = self.visit(call.op) | ||
| new_args = [self.visit(arg) for arg in call.args] | ||
| return Call(new_fn, new_args, call.attrs) | ||
|
|
||
| def visit_var(self, rvar): | ||
| return rvar | ||
|
|
||
| def visit_global_id(self, global_var): | ||
| return global_var | ||
|
|
||
| def visit_if(self, ite): | ||
| return If( | ||
| self.visit(ite.guard), | ||
| self.visit(ite.true_b), | ||
| self.visit(ite.false_b)) | ||
|
|
||
| def visit_tuple(self, tup): | ||
| return Tuple([self.visit(field) for field in tup.fields]) | ||
|
|
||
| def visit_tuple_getitem(self, op): | ||
| tuple_value = self.visit(op.tuple_value) | ||
| if not tuple_value.same_as(op.tuple_value): | ||
| return TupleGetItem(tuple_value, op.index) | ||
| return op | ||
|
|
||
| def visit_global_var(self, gvar): | ||
| return gvar | ||
|
|
||
| def visit_op(self, op): | ||
| return op | ||
|
|
||
| def visit_constant(self, const): | ||
| return const | ||
|
|
||
| def visit_constructor(self, con): | ||
| return con | ||
|
|
||
| def visit_match(self, m): | ||
| return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern]) | ||
|
|
||
| def visit_ref_new(self, r): | ||
| return RefNew(self.visit(r.value)) | ||
|
|
||
| def visit_ref_write(self, r): | ||
| return RefWrite(self.visit(r.ref), self.visit(r.value)) | ||
|
|
||
| def visit_ref_read(self, r): | ||
| return RefRead(self.visit(r.ref)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe replace this line with the new API?
https://github.com/dmlc/tvm/blob/01cfdb7ebcef0abc9282f68c3c4484974a6cc6e1/src/relay/pass/fuse_ops.cc#L702
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I was cherry-picking changes from an old branch 😄 got distracted at TVM conference while working on this PR.