[PyTorch] Branching operations #1027
Merged
timmoon10 merged 15 commits into NVIDIA:main on Aug 10, 2024
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 (Collaborator, Author):
/te-ci pytorch
ptrendx
reviewed
Aug 1, 2024
[review context]
        "are not compatible"
    )
    # Check output tensor dims
ptrendx (Member):
I wonder if we need to do this here (same for input) or maybe we could rely on the error checking on the C++ side to minimize CPU overhead?

timmoon10 (Collaborator, Author):
I think that would be a good optimization in the future, especially since the linear functional API is used in multiple operations.
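For context, the Python-side shape validation under discussion is roughly of the following form. This is a hypothetical sketch (the function name, arguments, and messages are illustrative, not the code in this PR), showing the kind of per-call check that could instead be delegated to the C++ side:

```python
def check_linear_dims(input_shape, weight_shape, out_shape=None):
    """Validate dims for a linear op: input (..., in_features),
    weight (out_features, in_features). Hypothetical sketch of the
    Python-side checks discussed above; the real implementation differs."""
    out_features, in_features = weight_shape
    if input_shape[-1] != in_features:
        raise ValueError(
            f"Input dims {tuple(input_shape)} and weight dims "
            f"{tuple(weight_shape)} are not compatible"
        )
    # Check output tensor dims against the shape implied by input and weight
    expected = tuple(input_shape[:-1]) + (out_features,)
    if out_shape is not None and tuple(out_shape) != expected:
        raise ValueError(
            f"Output dims {tuple(out_shape)} do not match expected {expected}"
        )
    return expected
```

Checks like these run on every call from Python, which is why folding them into the C++ error handling could reduce CPU overhead.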
ptrendx left several further review comments on Aug 1, 2024.
Commit (force-pushed from 82b83c9 to 2679fbf):
Output tensor dtype and device take precedence over weight tensor in linear functional API. Move some index calculation to fuser constructor. Avoid some unnecessary dereferences.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
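The precedence rule in the commit message can be sketched as follows. This is a hypothetical helper (names and signature are illustrative, and dtype/device are modeled as plain values rather than torch objects): when the caller supplies an output tensor, its dtype and device win; otherwise the weight tensor's are used.

```python
def resolve_dtype_device(weight_dtype, weight_device,
                         out_dtype=None, out_device=None):
    """Pick the compute dtype/device for a linear functional API.
    Hypothetical sketch of the precedence rule from the commit message:
    the output tensor's dtype and device, when provided, take
    precedence over the weight tensor's."""
    dtype = out_dtype if out_dtype is not None else weight_dtype
    device = out_device if out_device is not None else weight_device
    return dtype, device
```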
timmoon10 (Collaborator, Author):
/te-ci pytorch
ptrendx (Member):
Could you comment on how the change from your last commit helped with the unittest failures? The change from list comprehension to the for loop should not change the behavior, right?

timmoon10 commented Aug 9, 2024
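As an aside on the comprehension-vs-loop question: for building a list the two are indeed equivalent, though not interchangeable in every respect. One example (unrelated to whatever fixed the unittests here): in Python 3 a plain for loop leaks its iteration variable into the enclosing scope, while a comprehension keeps it local.

```python
# A list comprehension and the equivalent for loop build the same list...
squares_a = [n * n for n in range(5)]

squares_b = []
for m in range(5):
    squares_b.append(m * m)

assert squares_a == squares_b  # identical results

# ...but only the plain loop leaves its iteration variable behind:
assert "m" in dir()      # 'm' leaked from the for loop
assert "n" not in dir()  # comprehension variable stays local (Python 3)
```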
timmoon10 (Collaborator, Author):
/te-ci pytorch
ptrendx approved these changes Aug 9, 2024
Description
This PR modifies the operation-based API (#707) to support some simple branching behavior: operations can now accept extra tensor inputs and generate extra tensor outputs. This enables fusions like GEMMs with beta=1. Support for multiple inputs will also be necessary for cross-attention (and SSMs?). Note that we are not planning to support more complicated structures, since that would take us down the road of general graph compilers.
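A GEMM with beta=1 accumulates into an existing output, out = beta * out + A @ B, which is what makes it a useful fusion target: the preceding operation's output tensor can be consumed in place instead of being added in a separate kernel. A minimal numeric sketch of the semantics (plain Python for clarity, not the Transformer Engine API or the fused kernels in this PR):

```python
def gemm(a, b, out=None, beta=0.0):
    """Plain-Python sketch of GEMM with accumulation:
    result = beta * out + a @ b. Illustrates the beta=1 semantics only;
    the actual fused implementation lives in the CUDA/C++ backend."""
    rows, inner, cols = len(a), len(b), len(b[0])
    # Seed the result with the scaled accumulator (zeros if no out given)
    result = [[beta * (out[i][j] if out is not None else 0.0)
               for j in range(cols)] for i in range(rows)]
    for i in range(rows):
        for k in range(inner):
            for j in range(cols):
                result[i][j] += a[i][k] * b[k][j]
    return result
```

With beta=0 this is an ordinary matrix multiply; with beta=1 and an existing out tensor, the add is folded into the GEMM itself.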
Type of change
Changes
beta=1
Checklist: