[Unity][Transform] Some Improvements on pass DecomposeOps #14512
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
Review thread on `Expr SimplifyLayerNorm(const Call& call)`:
We don't need to decompose LayerNorm since it can be implemented more efficiently as a single op
Is there any complication for fusion if we have composite ops in general (not specifically LayerNorm)? I.e., is there a case where we cannot fuse the composite op, but we can once we decompose it?
Ah, here we simplify LayerNorm for now just because it is easy to get its gradient. If it is more efficient as a single op, maybe in the future we can remove it from this pass once we finish the gradient function for the whole LayerNorm.
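To make the gradient argument concrete, here is a minimal sketch in plain Python (not TVM code; the function name and the 1-D shape are illustrative assumptions) of the kind of decomposition the pass produces: LayerNorm rewritten as simple mean / variance / normalize steps, each of which has a well-known gradient.

```python
# Illustrative sketch only: layer_norm expressed as the simple ops it
# decomposes into. Plain Python over 1-D lists; real TVM code operates
# on Relax expressions, not Python values.
from math import sqrt

def layer_norm_decomposed(x, gamma, beta, eps=1e-5):
    """LayerNorm over a 1-D list, written as a chain of simple ops."""
    n = len(x)
    mean = sum(x) / n                          # reduce: mean
    var = sum((v - mean) ** 2 for v in x) / n  # reduce: (biased) variance
    inv_std = 1.0 / sqrt(var + eps)            # elementwise: rsqrt
    return [g * (v - mean) * inv_std + b       # elementwise: scale and shift
            for v, g, b in zip(x, gamma, beta)]

out = layer_norm_decomposed([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)
```

Each intermediate step (mean, variance, rsqrt, scale-and-shift) can be differentiated independently, which is why decomposing first makes gradient construction easy.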
Or maybe we can add an argument to the pass to let the user specify which operators they want to simplify?
Since decomposing LayerNorm is needed for training but may hurt inference performance, let's only decompose it in training mode.
> Is there any complication for fusion if we have composite ops in general?

Yes, we need to be aware of the implementation when setting the op pattern kind (kOpaque or kInjective); this leads to different fusion results.
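A toy illustration of that point (plain Python, not TVM's fusion algorithm; the grouping rule below is a deliberate simplification, and real reductions like mean/variance would have kind kCommReduce rather than kInjective): a composite op marked kOpaque acts as a fusion barrier, while its decomposed pieces can join a fused group with their neighbors.

```python
# Toy model of fusion grouping by op pattern kind. The kind names mirror
# TVM's OpPatternKind, but the greedy rule here is only for illustration.
K_INJECTIVE = "kInjective"
K_OPAQUE = "kOpaque"

def fuse_groups(ops):
    """Greedily fuse runs of injective ops; an opaque op stands alone."""
    groups, current = [], []
    for name, kind in ops:
        if kind == K_OPAQUE:
            if current:
                groups.append(current)
                current = []
            groups.append([name])      # opaque op: its own group
        else:
            current.append(name)       # injective op: extend the run
    if current:
        groups.append(current)
    return groups

# layer_norm as a single opaque op: it breaks fusion with its neighbors.
single = fuse_groups([("add", K_INJECTIVE),
                      ("layer_norm", K_OPAQUE),
                      ("relu", K_INJECTIVE)])

# After decomposition into simple (here all treated as injective) pieces,
# everything fuses into one group.
decomposed = fuse_groups([("add", K_INJECTIVE), ("mean", K_INJECTIVE),
                          ("variance", K_INJECTIVE), ("normalize", K_INJECTIVE),
                          ("relu", K_INJECTIVE)])
```

In the opaque case the graph splits into three groups; after decomposition the same computation fuses into one, which is the kind of difference the pattern-kind choice produces.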
sunggg left a comment:
LGTM! Thank you for the improvement.
Just a few comments.
Prior to this PR, we first had a pass named `SimplifyNormInference` (#14221), which is used to decompose `batch_norm` into simple operators for optimization during inference. There were some follow-up changes: one for the `tensor_to_shape` op, and one making `SimplifyNormInference` support op simplification during training (because `batch_norm` behaves differently in inference and training). So this PR polishes the changes in #14282, mainly in the following aspects:

- Add an argument `func_name` to specify the function we want to apply the pass to (or let it be `None` to simplify all functions). This is because sometimes we only want to decompose operators in a specified function.
- Make the code clear and easy to read. Now we use an `if-else if-else if ...` pattern to recognize the op we want to decompose. In the future, maybe we can introduce a map, or register the decomposition policy in the op attribute, just like what we do in the `LegalizeOps` pass.
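The registry idea above can be sketched as follows (plain Python, not the actual TVM implementation; the registry name, the decorator, and the op-name strings are all illustrative assumptions, loosely modeled on how `LegalizeOps` looks up per-op rules):

```python
# Hypothetical sketch: replacing an if-else-if chain over op names with a
# registry that maps each op name to its decomposition rule.
DECOMPOSE_RULES = {}

def register_decompose(op_name):
    """Decorator registering a decomposition rule for one op."""
    def _register(fn):
        DECOMPOSE_RULES[op_name] = fn
        return fn
    return _register

@register_decompose("relax.nn.layer_norm")
def _decompose_layer_norm(call):
    # Placeholder for the real rewrite into mean/variance/normalize.
    return f"decomposed({call})"

def decompose(op_name, call):
    """Dispatch via the registry; ops without a rule pass through unchanged."""
    rule = DECOMPOSE_RULES.get(op_name)
    return rule(call) if rule else call
```

Compared with the `if-else if` chain, new ops can be supported by registering a rule next to the op definition instead of editing the pass itself.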