-
Notifications
You must be signed in to change notification settings - Fork 83
Rework StructuredToMemref pass to no longer use memref.extract_strided_metadata #140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This reverts commit 9021cbf.
manbearian
approved these changes
Jun 18, 2024
Member
manbearian
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good @nhat-nguyen . I left a few comments that you might want to consider.
manbearian
approved these changes
Jul 8, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Motivation
tt.addptrproduces scalar pointers, which are kept intact by the pointer analysis pass. In the latest fused-attention model, block pointers take the result of att.addptras the base pointer. This behaviour is contrary to the current assumption that block pointers always take in the pointers directly from the kernel arguments; it exposes a bug where we ignore the offset of the buffer (because pointers from kernel arguments are assumed to have offset 0).Additionally, even though we already lower
tt.addptrwith scalars to a pair ofmemref.reinterpret_castandmemref.extract_strided_metadata, loweringmemref.extract_strided_metadatato work with Microsoft Maia is rather tricky because we don't have primitives to keep track of the current offset of a buffer.To fix both issues, I have reworked the StructuredToMemref pass to no longer use
memref.extract_strided_metadata.Background
Lowering a sequence of
tt.addptrtomemref.reinterpret_castis tricky becausememref.reinterpret_castdoes not "remember" the offset of the input buffer. Currently, we leveragememref.extract_strided_metadatato carry on the offset and simplify the lowering oftt.addptrin loops.Ideally, for each result produced by
tt.addptrneeds to be converted to a pair ofmemrefandindexvalues that keep track of both the buffer and the index. Fortunately, implementing this approach is much simpler with the introduction of the 1->N type conversion infrastructure. This PR removes the usage ofmemref.extract_strided_metadataand computes the offsets directly.Technical details
We leverage the 1->N conversion infrastructure to convert tt.addptr for
scalar to memref.reinterpret_cast.
A tt.addptr has the following form:
where %new_ptr and %ptr have tt.ptr type, and %offset is of index type.
With this form, there can be a chain of tt.addptr where we keep adding
offsets to an existing pointer:
Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that
the pointers can be used by affine.load and affine.store (lowered from
tt.load and tt.store).
A memref.reinterpret_cast op also takes an offset and returns a memref in a
similar fashion to tt.addptr:
However, since the semantic of memref.reinterpret_cast is different,
the following lowering would be incorrect for the sequence of tt.addptr
above:
The above sequence is equivalent to:
In other word, memref.reinterpret_cast ignores the current offset of the
input buffer.
Therefore, we have to manually track the offset for each addptr by lowering
to the following form:
Each tt.addptr is lowered to a pair of arith.addi that accumulates the
current offset before using that offset to the reinterpret_cast.