[Relay] change device annotation from post DFS to recursive #6124
cc @mbrookhart
mbrookhart
left a comment
Can you please add a test showing the problem/desired change in behavior?
tmoreau89
left a comment
Thanks @zhanghaohit for this PR, I second @mbrookhart's request to add a test case.
Thanks @mbrookhart and @tmoreau89 for the suggestion. I've added a test, which would fail in the original code.
mbrookhart
left a comment
Need to fix lint. A couple of nitpicks here, but overall, I think it looks good now. Kind of wondering why the unit test doesn't fail earlier.
    int dev_type_ = -1;
    int out_dev_type_ = -1;
I don't love making this class state, precisely because you have to do crazy things with maintaining state in your recursive calls. You could do it as a set of recursive arguments, but that kind of requires re-implementing with ExprFunctor...so maybe this is the cleanest solution.
    annotated_expr = annotated()
    expected_expr = expected()
    assert tvm.ir.structural_equal(annotated_expr, expected_expr)
I'm curious that master passes this check, but fails on line 377. Why doesn't structural equal properly resolve the error in the device copy op?
I think master fails on line 379, right? The device type of log2 is not correctly marked.
Up to this line, annotated_expr and expected_expr are exactly the same. The device_copy op is inserted correctly. We haven't gone through the device propagation yet.
👍 You're right, I misread my first test.
mbrookhart
left a comment
LGTM.
Flaky test?
Please trigger the CI again.
junrushao
left a comment
LGTM👍
5a47476 to 61a9bc8
@tmoreau89 Could you help merge this PR? Thanks.
tmoreau89
left a comment
LGTM
Thank you @zhanghaohit, @mbrookhart, @junrushao1994, @zhiics; the PR has been merged.
This is related to #5840 and split from PR #5842.
Originally, the device type was propagated based on the post-DFS traversal order of the graph, which may give inconsistent results if the argument order changes. In addition, it handles some cases wrongly, e.g., the first residual block in Resnet50. The first few layers in Resnet50 are depicted in the following figure (top to bottom is in DFS order). Basically, we want all the layers to run on the FPGA device, except the first and last few layers. In the original device propagation algorithm, which follows the post-DFS order, we encounter copy2 first, so the three conv2d nodes in grey are marked with the source device type of copy2 (i.e., CPU), which is not correct.
By changing the device annotation behaviour, we can support more complex graph structures.
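To make the difference concrete, here is a minimal toy sketch in plain Python. It is not TVM's implementation: the graph, the node names (pool, conv_a, conv_s, and so on), and both helper functions are hypothetical, and the "linear" variant is only a simplified model of order-based propagation as described above. It assumes a residual-style block in which both branches are fed through their own device copy from CPU to FPGA.

```python
CPU, FPGA = "cpu", "fpga"

# Hypothetical toy graph: node -> ordered list of input nodes.
GRAPH = {
    "pool": [],
    "copy1": ["pool"],
    "conv_a": ["copy1"],
    "conv_b": ["conv_a"],
    "conv_c": ["conv_b"],
    "copy2": ["pool"],
    "conv_s": ["copy2"],           # shortcut branch
    "add": ["conv_c", "conv_s"],
}
# Device-copy nodes: node -> (source device, destination device).
COPIES = {"copy1": (CPU, FPGA), "copy2": (CPU, FPGA)}


def post_dfs(graph, root):
    """Return the nodes reachable from root in post-DFS order."""
    order, seen = [], set()

    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for inp in graph[node]:
            visit(inp)
        order.append(node)

    visit(root)
    return order


def propagate_linear(graph, copies, root, out_dev):
    """Simplified order-based model: walk the flattened post-DFS list
    backwards and switch device whenever a copy node is seen, ignoring
    which nodes actually feed that copy."""
    dev, cur = {}, out_dev
    for node in reversed(post_dfs(graph, root)):
        if node in copies:
            cur = copies[node][0]
        else:
            dev[node] = cur
    return dev


def propagate_recursive(graph, copies, root, out_dev):
    """Recursive model: the device is passed down along real data-flow
    edges, so a copy only affects its own inputs."""
    dev, seen = {}, set()

    def visit(node, d):
        if (node, d) in seen:
            return
        seen.add((node, d))
        if node in copies:
            d = copies[node][0]    # inputs of a copy live on its source device
        else:
            dev[node] = d
        for inp in graph[node]:
            visit(inp, d)

    visit(root, out_dev)
    return dev


linear = propagate_linear(GRAPH, COPIES, "add", FPGA)
recursive = propagate_recursive(GRAPH, COPIES, "add", FPGA)
print(linear["conv_a"], recursive["conv_a"])   # cpu fpga
```

In this toy graph the linear variant labels conv_a, conv_b and conv_c as CPU because copy2 precedes them in the reversed list, while the recursive variant keeps them on FPGA. Swapping the inputs of add to ["conv_s", "conv_c"] changes the linear result (conv_s becomes CPU instead) but leaves the recursive result unchanged.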