-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[ONNX][Relay] Add dynamic unsqueeze / expand_dims op #9039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX][Relay] Add dynamic unsqueeze / expand_dims op #9039
Conversation
* main: [3/10] Moved TIR generation from Python to C++ for CMSIS-NN (apache#8951) Support match pvar with dtype constraint (apache#9016)
mbrookhart
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I get the multiple back to back expand_dims calls.
| "test_unsqueeze_three_axes", | ||
| "test_unsqueeze_two_axes", | ||
| "test_unsqueeze_unsorted_axes", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add these to device-specific skips below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm actually all targets are really slow it seems. The LLVM target takes > 3 minutes for test_unsqueeze_two_axes.
| if isinstance(axis, int): | ||
| return _make.expand_dims(data, axis, num_newaxis) | ||
| if isinstance(axis, Expr): | ||
| # TODO (AndrewZhaoLuo): investigate performance issues with consecutive | ||
| # dynamic expand_dims on non-llvm targets. | ||
| for _ in range(num_newaxis): | ||
| # Dynamic rank is not well supported so we can only increase rank | ||
| # by a static amount (e.g. 1) so we have to do this | ||
| data = _dyn_make.expand_dims(data, axis) | ||
| return data | ||
| raise ValueError(f"Unknown type for axis: {type(axis)}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't we do it all at once if we know that num_newaxis is static?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can, done.
| # dynamic expand_dims on non-llvm targets. | ||
| for i in range(num_new_axis): | ||
| axis = relay.TupleGetItem(axes, i) | ||
| # Unpack scalar | ||
| axis = relay.reshape(axis, []) | ||
| axis = relay.If( | ||
| axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64") | ||
| ) | ||
| result = _op.expand_dims(result, axis) | ||
| return result | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, I think this should be doable as one call?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be doable if we change the interface to accept a list of sorted axis but it'll be quite a bit more complicated.
|
PTAL @mbrookhart |
* main: Add back missing __init__.py to unbreak CI. (apache#9052) [Meta Schedule][M3b] Builder (apache#9044)
* initial dyn unsqueeze example * simplify, properly unpack scalar * basic tests * squish bugs -- assign proper types * working topi * fix things * temp work * fix casting to int64 * shape encoding method for axis * working shape encoding metric * add comment * move to non-rank encoded axis * failing regime * fix * it works! * add test * add comment on shape func * remove unused topi * undo some file changes * more cleanup * newline * clean up * clean up * enable multiple axis tests * move tests to dynamic op * Update docs * add converter * initial dyn unsqueeze example * simplify, properly unpack scalar * basic tests * squish bugs -- assign proper types * working topi * fix things * temp work * fix casting to int64 * shape encoding method for axis * working shape encoding metric * add comment * move to non-rank encoded axis * failing regime * fix * it works! * add test * add comment on shape func * remove unused topi * undo some file changes * more cleanup * newline * clean up * clean up * enable multiple axis tests * move tests to dynamic op * Update docs * add converter * working tests * add test, remove unneeded file * fix things * more lint * more lint * pick things * disable opencl tests * unsqueeze tests * clean up * dyn stuff * add num_newaxis Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
* initial dyn unsqueeze example * simplify, properly unpack scalar * basic tests * squish bugs -- assign proper types * working topi * fix things * temp work * fix casting to int64 * shape encoding method for axis * working shape encoding metric * add comment * move to non-rank encoded axis * failing regime * fix * it works! * add test * add comment on shape func * remove unused topi * undo some file changes * more cleanup * newline * clean up * clean up * enable multiple axis tests * move tests to dynamic op * Update docs * add converter * initial dyn unsqueeze example * simplify, properly unpack scalar * basic tests * squish bugs -- assign proper types * working topi * fix things * temp work * fix casting to int64 * shape encoding method for axis * working shape encoding metric * add comment * move to non-rank encoded axis * failing regime * fix * it works! * add test * add comment on shape func * remove unused topi * undo some file changes * more cleanup * newline * clean up * clean up * enable multiple axis tests * move tests to dynamic op * Update docs * add converter * working tests * add test, remove unneeded file * fix things * more lint * more lint * pick things * disable opencl tests * unsqueeze tests * clean up * dyn stuff * add num_newaxis Co-authored-by: Andrew Zhao Luo <andrewzhaoluo@system76-pc.localdomain>
Adds new dynamic unsqueeze op. There is an issue with performance however. On targets
metalandopenclwhich I tested the ONNX and unit tests take multiple hours to run. For now, I've disabled offending tests on such targets. The performance issues seem to happen when we have multiple consecutive dynamic reshape ops.