@nhat-nguyen nhat-nguyen commented Jun 17, 2024

Motivation

tt.addptr produces scalar pointers, which the pointer analysis pass leaves intact. In the latest fused-attention model, block pointers take the result of a tt.addptr as their base pointer. This contradicts the current assumption that block pointers always take their base pointers directly from the kernel arguments, and it exposes a bug where we ignore the offset of the buffer (because pointers from kernel arguments are assumed to have offset 0).

Additionally, even though we already lower tt.addptr with scalars to a pair of memref.reinterpret_cast and memref.extract_strided_metadata, lowering memref.extract_strided_metadata to work with Microsoft Maia is rather tricky because we don't have primitives to keep track of the current offset of a buffer.

To fix both issues, I have reworked the StructuredToMemref pass to no longer use memref.extract_strided_metadata.

Background

Lowering a sequence of tt.addptr to memref.reinterpret_cast is tricky because memref.reinterpret_cast does not "remember" the offset of the input buffer. Currently, we use memref.extract_strided_metadata to carry the offset forward and simplify the lowering of tt.addptr in loops.

Ideally, each result produced by tt.addptr should be converted to a pair of memref and index values that keep track of both the buffer and its current offset. Fortunately, implementing this approach is much simpler with the introduction of the 1->N type conversion infrastructure. This PR removes the usage of memref.extract_strided_metadata and computes the offsets directly.

Technical details

We leverage the 1->N conversion infrastructure to convert scalar tt.addptr
ops to memref.reinterpret_cast.

A tt.addptr has the following form:

 %new_ptr = tt.addptr %ptr %offset

where %new_ptr and %ptr have tt.ptr type, and %offset is of index type.

With this form, there can be a chain of tt.addptr where we keep adding
offsets to an existing pointer:

 %ptr_1 = tt.addptr %ptr_0 %offset
 %ptr_2 = tt.addptr %ptr_1 %offset
 %ptr_3 = tt.addptr %ptr_2 %offset

Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that
the pointers can be used by affine.load and affine.store (lowered from
tt.load and tt.store).

A memref.reinterpret_cast op also takes an offset and returns a memref in a
similar fashion to tt.addptr:

  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1], strides: [1]
      : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>

However, because the semantics of memref.reinterpret_cast are different,
the following lowering would be incorrect for the sequence of tt.addptr
above:

  %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
  %cast_2 = memref.reinterpret_cast %cast_1 to offset [%offset]
  %cast_3 = memref.reinterpret_cast %cast_2 to offset [%offset]

The above sequence is equivalent to:

  %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
  %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset]
  %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset]

In other words, memref.reinterpret_cast ignores the current offset of the
input buffer.
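
To make the difference concrete, here is a toy Python model of the two semantics. It is only an illustration and not code from the pass: MemRef, addptr, and reinterpret_cast are hypothetical stand-ins, and the step of 4 elements is an arbitrary value chosen for the example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MemRef:
    """Toy view of a buffer: a base name plus an absolute element offset."""
    base: str
    offset: int = 0


def addptr(ptr: MemRef, offset: int) -> MemRef:
    # Pointer arithmetic is relative, so offsets accumulate along a chain.
    return MemRef(ptr.base, ptr.offset + offset)


def reinterpret_cast(src: MemRef, offset: int) -> MemRef:
    # Models the behaviour described above: the new offset is measured from
    # the base buffer, so the offset already carried by `src` is dropped.
    return MemRef(src.base, offset)


arg0 = MemRef("arg0")
step = 4  # arbitrary illustrative offset

# Chain of three tt.addptr-like steps versus the naive chained casts.
ptr_3 = addptr(addptr(addptr(arg0, step), step), step)
cast_3 = reinterpret_cast(reinterpret_cast(reinterpret_cast(arg0, step), step), step)

print(ptr_3.offset)   # 12
print(cast_3.offset)  # 4
```

In this model the pointer chain ends at offset 12 while the naive cast chain ends at offset 4, which is the mismatch the lowering has to avoid.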

Therefore, we have to manually track the offset for each addptr by lowering
to the following form:

 %offset_1 = arith.addi %cst_0 %offset
 %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset_1]

 %offset_2 = arith.addi %offset_1 %offset
 %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset_2]

 %offset_3 = arith.addi %offset_2 %offset
 %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset_3]

Each tt.addptr is lowered to a pair of ops: an arith.addi that accumulates
the current offset, and a memref.reinterpret_cast on the original buffer
that uses the accumulated offset.
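
The offset-tracking scheme can be sketched as a small Python generator of pseudo-IR. This is a simplified illustration under stated assumptions, not the actual StructuredToMemref implementation: lower_addptr_chain is a hypothetical helper, the emitted text follows the abbreviated notation used in this description rather than full MLIR syntax, and the chain is assumed to start from a kernel argument whose offset is the constant %cst_0.

```python
def lower_addptr_chain(base: str, offsets: list[str]) -> list[str]:
    """Emit pseudo-IR for a chain of tt.addptr ops rooted at `base`.

    Each pointer is represented as a (buffer, running offset) pair; only the
    running offset changes from one step to the next.
    """
    lines = []
    running = "%cst_0"  # assumption: kernel-argument pointers start at offset 0
    for i, off in enumerate(offsets, start=1):
        acc = f"%offset_{i}"
        # Accumulate the offset first ...
        lines.append(f"{acc} = arith.addi {running} {off}")
        # ... then always cast from the original buffer with the total offset.
        lines.append(f"%cast_{i} = memref.reinterpret_cast {base} to offset [{acc}]")
        running = acc
    return lines


ir = lower_addptr_chain("%arg0", ["%offset"] * 3)
print("\n".join(ir))
```

Because every cast is taken from the original buffer, no step depends on the previous cast "remembering" its offset; the running index value carries that state instead.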

@nhat-nguyen nhat-nguyen changed the title Rework StructuredToMemref pass Rework StructuredToMemref pass to no longer use memref.extract_strided_metadata Jun 18, 2024
@nhat-nguyen nhat-nguyen requested a review from manbearian June 18, 2024 18:29
@nhat-nguyen nhat-nguyen marked this pull request as ready for review June 18, 2024 18:29
@manbearian manbearian left a comment

Looks good, @nhat-nguyen. I left a few comments that you might want to consider.

@nhat-nguyen nhat-nguyen enabled auto-merge (squash) July 3, 2024 20:17
@nhat-nguyen nhat-nguyen requested a review from manbearian July 3, 2024 20:17
@nhat-nguyen nhat-nguyen merged commit 89cdd22 into main Jul 8, 2024
@nhat-nguyen nhat-nguyen deleted the nhat/addptr branch July 8, 2024 16:38
3 participants