[Vulkan] Add VK_NV_cooperative_matrix support #14770
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
VK_NV_cooperative_matrix support (force-pushed from 926ecba to 731e361)
@Lunderberg can you help review this PR?
Lunderberg left a comment
I think everything looks reasonable, though I'm going to need to wrap my head around the changes in CodeGenSPIRV::VisitStmt_(const ForNode*) before commenting on it.
Lunderberg left a comment
Overall, it looks good! I have one change to request, to handle cases where a cooperative matrix is allocated within the body of a loop, but otherwise good to go!
// Loop head
builder_->StartLabel(head_label);

// In a normal matmul, the update semantics c += a * b is implemented by load and store to the
If I understand correctly, the special handling for cooperative matrices is required because (1) the cooperative matrices have buffer-type semantics in TIR and (2) the cooperative matrices have value-type semantics in SPIR-V. Due to (1), the cooperative matrices may be mutated, unlike a variable in a LetStmt. Due to (2), the mutation requires an OpPhi to join the values from the two predecessor branches, unlike in a BufferStore/BufferLoad. Is that understanding correct?
Exactly right. I really liked your explanation, so I'll update the comment here based on it.
From the TIR perspective, cooperative matrices (and also the corresponding concept in CUDA WMMA) are nothing but normal buffers with a specialized scope. In particular, they can have any shape. In SPIR-V, on the other hand, they have a fixed shape and value semantics.
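Roughly, the idea is to handle each accumulated matrix the same way the loop induction variable is already handled in CodeGenSPIRV::VisitStmt_(const ForNode*). Sketch only; coop_matrix_stype, init_value, updated_value, and continue_label are placeholder names, not the exact identifiers used in this PR:

// Sketch: each C matrix accumulated in the loop becomes a loop-carried SSA
// value joined in the loop header by an OpPhi. Incoming 0 is the value before
// entering the loop (e.g. from cooperative_matrix_fill_NV); incoming 1 is the
// value produced by OpCooperativeMatrixMulAddNV in the previous iteration.
spirv::PhiValue mat_phi = builder_->MakePhi(coop_matrix_stype, 2);
mat_phi.SetIncoming(0, init_value, init_label);
// ... the loop body computes updated_value for this matrix ...
mat_phi.SetIncoming(1, updated_value, continue_label);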
I assume the intention here is to scalarize each 16x16 fragment and represent it with a variable. This transformation does not bring a performance gain compared to the approach of using a vector of 16x16 fragments and loading/storing from it, since the shader compiler can easily do the same scalarization. Also, if you want to unroll the loop, the address computation in each unrolled iteration is not CSEed, so you will see a larger code size with all these non-CSEed address computations.
Looks like my claim "Cooperative matrices cannot be loaded from / stored to a buffer" was based on a wrong assumption. I'd still argue that my approach is simpler and cleaner than the load / store based approach, even if there is no other benefit. A con of my approach is that it differs from how we treat vector types - I don't have a good answer to the question "why should cooperative matrices get special handling" other than somewhat subjective opinions (for example, I don't like how we support CUDA WMMA today, which is based on load / store into an array of fragments). One thing that's not clear to me: I believe the spec doesn't say how many bytes each matrix takes up, so I'm not sure it is possible to allocate a buffer of them and address into it. Maybe we don't need to worry about such issues, since memory is typed in SPIR-V and size / offset are in terms of the number of elements.
I'm happy to change the implementation to a conventional load / store based approach if people prefer it that way. I'd love to get @Lunderberg's opinion on this point (see Mei's PR #14817, which implements the same extension with an alternative approach very similar to how our CUDA codegen supports WMMA).
I should also remove the term "unroll" in the comment, since nothing is really unrolled by the codegen. For example, here is a good 4k matmul TIR https://gist.github.com/masahi/dc7173d39f4b376884f49019cde24826 and the generated spv https://gist.github.com/masahi/3fd481c8037b1ed55876c89305ec02ef. The two representations mostly map one to one (no further unrolling is done on the SPIR-V side). But it does assume that unrolling has been done in the schedule so that the element offset becomes a constant. I admit this is another con of my approach.
Mei's approach in #14817, being based on the WMMA intrinsics originally developed for CUDA, inherits both the pros and cons of CUDA WMMA support in TVM. In particular, it allows the same schedule to be compiled for both CUDA and VK, and since CUDA WMMA is supported by our auto-tensorization infra, auto tensorization might also work out of the box for VK. That's a big advantage. cc @vinx13 @tqchen
Being able to reuse existing schedules and auto tensorization for WMMA is a big advantage. I'm wondering if we can reuse the WMMA intrinsics, since we can still use the fragment index to identify a matrix (if we want to scalarize it).
Looking through, I like Mei's approach as well. It avoids imposing usage restrictions (e.g. buffers must be accessed either through coop matrix primitives or through load/store, but not both), and maintains the mutability semantics of both SPIR-V and TIR by avoiding long-lived coop matrix values in the generated SPIR-V.
if (op->kind == ForKind::kSerial) {
  // If this is a serial loop, record which C matrices are updated in this loop.
  tir::PostOrderVisit(op->body, [&accum_matrices](const ObjectRef& obj) {
Would this collection loop collect too many buffers in the case of a cache_read/write that occurs inside a loop? If a cooperative buffer's AllocateNode occurs within the body of a loop, then it cannot have a pre-loop value, nor can its value be passed to the next loop iteration.
I think this case would cause an error to occur in the auto mat = builder_->GetCooperativeMatrix that occurs below. The error would be suppressed if an earlier pass has hoisted the allocation to function-scope, but I don't think the codegen should require that to be the case.
Hmm, I hadn't thought about such possibilities... this code only collects the matrices that are accumulated in the reduction loop (the "C" matrices). I haven't put deep thought into it, but I don't see how cache_read / write would pose a problem. The C matrices should be allocated outside of the loop in which they are accumulated.
Or, are you imagining a scenario where the matrix to be accumulated is initialized from cooperative_matrix_load, rather than zero initialized inside the kernel before the reduction loop?
The issue I'm imagining would occur when the reduction loop occurs within another loop. For example, suppose I took the matrix multiplication unit test, but added a batch dimension (rough sketch below). The cooperative matrix loads wouldn't depend on the batch number for their shapes, and the declaration/allocation of the cooperative matrices could occur within the body of batch_i. However, the loop over batch_i shouldn't have an OpPhi associated with the cooperative matrices, because there is no value of the matrix prior to entering the loop over batch_i.
@T.prim_func
def main(
    X: T.Buffer((4, 16, 32), "float16"),
    W: T.Buffer((4, 32, 16), "float16"),
    compute: T.Buffer((4, 16, 16), "float32"),
):
    for batch_i in range(4):
        X_shared = T.decl_buffer((16, 32), "float16", scope="shared")
        W_shared = T.decl_buffer((32, 16), "float16", scope="shared")
        X_shared_cooperative_matrix_nv = T.decl_buffer(
            (16, 32), "float16", scope="cooperative_matrix_nv"
        )
        W_shared_cooperative_matrix_nv = T.decl_buffer(
            (32, 16), "float16", scope="cooperative_matrix_nv"
        )
        compute_cooperative_matrix_nv = T.decl_buffer(
            (16, 16), scope="cooperative_matrix_nv"
        )
        i_0_j_0_fused = T.launch_thread("blockIdx.x", 1)
        with T.launch_thread("threadIdx.x", 32) as tx:
            T.cooperative_matrix_fill_NV(
                compute_cooperative_matrix_nv.data, 0, 16, 16, T.float32(0)
            )
            # Continues with remainder of matrix multiplication
            for k_0 in range(2):
                ...
size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast<uint32_t>(constant_size);
shared_memory_bytes_used_ += num_bytes;
} else if (storage_scope.rank == runtime::StorageRank::kCooperativeMatrixNV) {
To avoid the over-collection issue, we could track the currently-active buffers with kCooperativeMatrixNV here. If we track which allocations are currently open, then the post_order_visit in VisitStmt_(const ForNode*) could limit its collection to buffers that are currently in scope.
// Track which cooperative-matrix allocations are currently open, so the loop
// visitor can ignore matrices whose allocation sits inside the loop body.
active_cooperative_matrices_.insert(op->buffer_var);
this->VisitStmt(op->body);
active_cooperative_matrices_.erase(op->buffer_var);
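The collection in VisitStmt_(const ForNode*) could then filter against that set. A rough sketch of what I have in mind (the BufferStore check below is only illustrative; the real condition would mirror however the PR currently detects updates to the C matrices):

tir::PostOrderVisit(op->body, [&](const ObjectRef& obj) {
  if (const auto* store = obj.as<BufferStoreNode>()) {
    Var buf_var = store->buffer->data;
    // Only loops that enclose the allocation get a loop-carried value (OpPhi)
    // for this matrix; matrices allocated inside the loop body are skipped.
    if (active_cooperative_matrices_.count(buf_var)) {
      accum_matrices.insert(buf_var.get());
    }
  }
});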
if (op->kind == ForKind::kSerial) {
  // If this is a serial loop, record which C matrices are updated in this loop.
  tir::PostOrderVisit(op->body, [&accum_matrices](const ObjectRef& obj) {
For deeply-nested loops, does this cause a performance issue, as it repeats the visit of the loop body once for every nested loop?
Good point. Although having such a deep loop nest in a GPU kernel is rare, I'll polish this.
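One possible way to polish it (sketch only; IsCooperativeMatrixUpdate is a hypothetical helper standing in for whatever check the per-loop lambda currently does) is to collect the accumulated matrices for all loops in a single pre-pass, keeping a stack of the serial loops that are currently open:

// Sketch: one walk over the kernel instead of one PostOrderVisit per loop.
class AccumMatrixCollector : public tir::StmtVisitor {
 public:
  std::unordered_map<const ForNode*, std::unordered_set<const VarNode*>> per_loop;

 private:
  void VisitStmt_(const ForNode* op) final {
    if (op->kind == ForKind::kSerial) loop_stack_.push_back(op);
    StmtVisitor::VisitStmt_(op);
    if (op->kind == ForKind::kSerial) loop_stack_.pop_back();
  }
  void VisitStmt_(const BufferStoreNode* op) final {
    // Record an update to a C matrix for every serial loop currently open.
    if (IsCooperativeMatrixUpdate(op)) {
      for (const ForNode* loop : loop_stack_) {
        per_loop[loop].insert(op->buffer->data.get());
      }
    }
    StmtVisitor::VisitStmt_(op);
  }
  std::vector<const ForNode*> loop_stack_;
};

The For-loop codegen could then just look up per_loop for the loop it is emitting, rather than re-visiting the body.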
This PR enables using tensor cores on the Vulkan target for NV devices. The semantics of the instructions are pretty much the same as WMMA. See https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_cooperative_matrix.html for the reference.
The script https://github.com/masahi/tensorir-experiment/blob/master/vk_cooperative_matrix_nv/test_4k.py shows the usage of this extension in a realistic matmul schedule. It gets 36 TFLOPs for f16f16f32 matmul on RTX 3070 (peak being 41 TFLOPs). It also demonstrates AutoTVM-style MS tuning for tiling parameters.
In contrast to our WMMA support, which relies on a complicated "fragment index" concept and on generating an "array" of fragments and indexing into it in the source code (not possible with SPIR-V), my solution materializes all cooperative matrices in the codegen and explicitly unrolls all instructions on them. Each matrix is identified by an element offset into the buffer with cooperative matrix scope, so the element offset is required to be a constant. I'm not sure if that is a reasonable assumption, but other than that I believe my solution is better than how we support WMMA for CUDA. See codegen_spirv.cc for details.

@tqchen @Lunderberg @vinx13 @spectrometerHBH