[relay][frontend] aten::copy_ support for pytorch #15502
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot
You've got a segfault from your test.
@masahi It's ready for review. Sorry for the late response.
```python
    torch._C._jit_pass_lower_all_tuples(graph)
```

```python
def _redirect_inplace_output(graph):
```
Please give an example of what this pass does by documenting the IR before / after this pass.
OK, an example has been added.
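An illustrative before/after of what the pass is meant to do (node names, constants, and types below are made up; this is not necessarily the exact example added in the PR):

```
# Before (illustrative): the traced graph returns the original input %x,
# so the effect of aten::copy_ dangles -- nothing consumes its output %out.
graph(%self : __torch__.InplaceCopy,
      %x : Tensor):
  ...
  %src  : Tensor = aten::ones(...)
  %view : Tensor = aten::slice(%x, %dim, %start, %end, %step)
  %out  : Tensor = aten::copy_(%view, %src, %non_blocking)
  return (%x)

# After _redirect_inplace_output (illustrative): the graph output is
# redirected to the value produced by aten::copy_, so the converter can
# see the mutation and build a functional (DAG) equivalent.
graph(%self : __torch__.InplaceCopy,
      %x : Tensor):
  ...
  %src  : Tensor = aten::ones(...)
  %view : Tensor = aten::slice(%x, %dim, %start, %end, %step)
  %out  : Tensor = aten::copy_(%view, %src, %non_blocking)
  return (%out)
```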
```python
        return x
```

```python
    inputs = torch.randn(10, 10)
    verify_model(InplaceCopy(), [inputs])
```
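A minimal sketch of a test along these lines, assuming `verify_model` is the test suite's existing helper; the body of `forward` is an assumption (any module that writes into a slice of its input via `copy_` would exercise this path), so the PR's actual module may differ:

```python
import torch


class InplaceCopy(torch.nn.Module):
    def forward(self, x):
        # Hypothetical body: overwrite the first five rows in place,
        # then return the (mutated) input tensor.
        x[:5, :].copy_(torch.ones(5, 10))
        return x


inputs = torch.randn(10, 10)
verify_model(InplaceCopy(), [inputs])  # verify_model: existing helper in the pytorch frontend tests
```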
Please add more tests, using various tricky examples to make sure that the conversion works.
I added more tests and tried to cover this function with various cases. Please let me know if you have any suggestions.
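Illustrative examples of the kind of trickier cases one might cover here (not necessarily the PR's exact tests): writing through integer indexing, which lowers to `aten::select`, and a strided slice destination with a broadcast scalar source:

```python
import torch


class CopyIntoSelectedRow(torch.nn.Module):
    def forward(self, x):
        # x is assumed to be (10, 10); x[3] traces to aten::select.
        x[3].copy_(torch.zeros(10))
        return x


class CopyWithStrideAndBroadcast(torch.nn.Module):
    def forward(self, x):
        # Strided slice destination; the scalar source is broadcast by copy_.
        x[1:9:2, 2:7].copy_(torch.tensor(1.0))
        return x


inputs = torch.randn(10, 10)
verify_model(CopyIntoSelectedRow(), [inputs])
verify_model(CopyWithStrideAndBroadcast(), [inputs])
```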
thanks
masahi left a comment:
Ok, let's try this approach for our first copy_ support.
Although #9375 was rejected, I tried a different way to support the `aten::copy_` op.

`aten::copy_` behaves differently from other in-place ops: it is "purely in-place". Other in-place nodes still relay their output to downstream users in the output graph (`torch.Graph`), so a DAG can still be constructed from them. `aten::copy_`, however, returns its destination tensor itself, which leaves all of its mutations dangling. For example, a torch module like the one sketched below generates a graph which returns `%x` itself.
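A hedged reconstruction of the kind of module being described (the exact example in the PR description may differ):

```python
import torch


class InplaceCopy(torch.nn.Module):
    def forward(self, x):
        # Write into a slice of the input; this traces to aten::slice
        # followed by aten::copy_, which mutates x in place.
        x[:5, :].copy_(torch.ones(5, 10))
        return x


traced = torch.jit.trace(InplaceCopy(), torch.randn(10, 10))
print(traced.graph)
# The printed graph ends with `return (%x)`: nothing consumes the output
# of the aten::copy_ node, so the mutation dangles in the DAG.
```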
My approach to handle this problem is:

1. In `from_pytorch`, after `_run_jit_passes` is called, insert a pass (`_redirect_inplace_output`) that redirects the output of `aten::copy_` at the torch level (`torch.Graph`).
2. From each `aten::copy_` node, walk its parents to collect the `aten::select` and `aten::slice` nodes, in order to generate the indices of the copied region (a rough sketch of this traversal is shown after this description). I referenced the PyTorch repository for this, specifically the behavior of the torch -> ONNX conversion.

I'm not familiar with making a PR to this repository, so please let me know if there is any feedback or questions.
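A rough, hypothetical sketch of the traversal in step 2 (`collect_copy_index_chain` is an invented name and this is not the PR's actual code; only `aten::copy_`, `aten::select`, and `aten::slice` come from the description above):

```python
def collect_copy_index_chain(copy_node):
    """Walk back from the destination of an aten::copy_ node through the
    aten::slice / aten::select nodes that produced it.  The collected
    nodes carry the dim/start/end/step (or dim/index) arguments needed to
    describe the written region on the root tensor."""
    chain = []
    dest = list(copy_node.inputs())[0]       # tensor being written into
    producer = dest.node()
    while producer.kind() in ("aten::slice", "aten::select"):
        chain.append(producer)
        dest = list(producer.inputs())[0]    # the tensor it sliced/selected from
        producer = dest.node()
    # outermost slice/select first, plus the root tensor value
    return list(reversed(chain)), dest
```

The collected `dim`/`start`/`end`/`step` (or `dim`/`index`) arguments can then be turned into a functional update on the root tensor, roughly analogous to how the torch -> ONNX exporter handles in-place writes into slices.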