-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Codegen] Use CUDA's half2 and nv_bfloat162 intrinsics for vector fp16/bf16 data types #15190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Collaborator
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Motivation
Currently, our CUDA codegen would not utilize CUDA's half2 and nv_bfloat162 intrinsics, and calls the scalar operators for each elements in the vector, which is not efficient. This PR improves the CUDA code by emitting half2 and nv_bfloat162 intrinsics when possible, and could potentially makes the generated program running faster (in case that nvcc didn't do this optimization for some programs).
The PR is based on #15183 and will be rebased to mainline after that PR get merged.
Example
Suppose a user is vectorizing the following operation:
Before this PR, TVM would emit the following CUDA code:
After this PR, TVM would emit code that uses half2 instrinsics directly:
cc @Hzfengsy @masahi @tqchen @junrushao @vinx13