Conversation
zasdfgbnm
left a comment
Good work! LGTM except some minor comments.
Please cherry-pick this to main without the change in csrc/predicate_compute.cpp.
This will add a trivial predicate, but performance should still be better than without the swizzle.
csrc/scheduler/matmul.h
```C++
//! A2 A3 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4
//! C1 C2 D1 D2
//! C3 C4 D3 D4
int swizzle_factor = 1;
```
Should we call this grid_swizzle_factor to distinguish it from the prologue "swizzle" that is used to remove bank conflicts?
```C++
auto inputs = fp16MatmulAtInput(M, N, K, layout);

FusionExecutor fe;
fe.setMeasureKernelTimeFlag(true);
```
I think this will add some debug printing when running this test? Could you comment this line out?
No debug prints. The only effect is to create cudaEvents and return the runtime through fe.kernelTimeMs() (otherwise it's just zero).
Ahhh, I thought it would automatically `std::cout << fe.kernelTimeMs()`. If there are no debug prints, then we can keep it.
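For context, the flag-gated timing behavior discussed above can be sketched in plain C++. This is an illustrative stand-in, not the real `FusionExecutor`: the actual implementation records cudaEvents around the kernel launch, while this sketch uses `std::chrono` so it runs anywhere. The class and member names here just mirror the API shape from the diff.

```cpp
#include <chrono>
#include <functional>

// Illustrative stand-in for FusionExecutor's timing behavior (not the real
// class): with the measure flag off, kernelTimeMs() simply returns 0 and
// nothing is printed; enabling it only records a duration for later query.
class TimedRunner {
 public:
  void setMeasureKernelTimeFlag(bool flag) { measure_ = flag; }

  void run(const std::function<void()>& kernel) {
    if (!measure_) {
      kernel();
      return;
    }
    auto start = std::chrono::steady_clock::now();
    kernel();  // the real executor brackets this with cudaEventRecord
    auto end = std::chrono::steady_clock::now();
    kernel_time_ms_ =
        std::chrono::duration<double, std::milli>(end - start).count();
  }

  double kernelTimeMs() const { return kernel_time_ms_; }

 private:
  bool measure_ = false;
  double kernel_time_ms_ = 0.0;
};
```

The key point, as noted above, is that enabling the flag has no side effect on output: the measured time is only exposed through the `kernelTimeMs()` query.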
```C++
// Gmem pipeline stage

for (auto layout : {MatmulLayout::TT}) {
```
Should this be kAllSupportedMatmulLayout?
Just removed it to keep the runtime short, as the test already checks four configs per layout.
The current test takes 15 s; with all layouts it would jump to 45 s.
OK, then let's just test one layout
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
Cherry-pick of the changes made in #87 into main. Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
I am seeing a trivial thread predicate in one matmul schedule in #87, and that trivial thread predicate causes a 10% slowdown, so I started working on removing it. My first step is a rewrite of `ParallelDimensionMap`. The new code is much shorter in lines of code, and should produce better results (for example, in `FusionParallelDimensionMap3_CUDA`, `blockDim.x` becomes `20`). This PR uses the following equation to calculate the extent of a parallel dim:

```C++
parallel_dim = simplifyExpr(max(extent1, extent2, ..., extentN));
```

Future work:
- [ ] Simplify the trivial predicate in mmigdal-nv#1. What we need to do:
  - When simplifying all expressions, assume `parallel_index < parallel_dim` in addition to https://github.com/NVIDIA/Fuser/blob/3e69f69024ec98c08b72889c1f32871963338bb7/csrc/expr_simplifier.cpp#L152
  - Omit the predicate not only when `isExact(ptype)`, but also when the loop extent equals `parallel_dim`.
- [ ] Simplify the following to use the extent computed in this PR: https://github.com/NVIDIA/Fuser/blob/548d5d2698f55205cbe84bec8ca2aff2051fb88b/csrc/executor.cpp#L660
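The max-over-extents rule above can be sketched with plain integers. This is illustrative only: in the actual code the extents are symbolic `Val`s and the maximum is fed through `simplifyExpr`, and the sample extents below are made up for the example.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical illustration (not the actual ParallelDimensionMap code):
// the launch extent of a parallel type, e.g. threadIdx.x, is the maximum
// of the extents of every IterDomain parallelized on it. With no mapped
// IterDomains the dimension is trivially 1.
int64_t parallelDimExtent(const std::vector<int64_t>& extents) {
  int64_t dim = 1;
  for (int64_t e : extents) {
    dim = std::max(dim, e);
  }
  return dim;
}
```

For instance, if IterDomains with (concrete) extents 4, 20, and 5 are all bound to `threadIdx.x`, the launch uses `blockDim.x = 20`; the symbolic version simply leaves `max(...)` to the expression simplifier when the extents are not constants.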
Adds a CTA swizzle to change the order in which the tiles of the output matrix are processed.
This swizzle increases data reuse from A and B when iterating over gridDim.x. In practice, CTAs turn out to be launched by iterating over gridDim.x first (the order is unspecified; it just happens to behave this way). As a result, the current wave contains CTAs that compute square sub-matrices of C, which increases the L2 hit rate.
The best factor seems to be 4; it will become part of the heuristics. Setting the factor to 1 disables the swizzle.
On an 8192x8192x8192 matmul with the default config, the speedup is about 20%.
On a 16384x16384x16384 matmul, the runtime drops from 51 ms to 38 ms, about 26%.
An extreme example is the following case: `MNK = 6144 6144 6144, layout = NT, stages = 0, cta_tile = 32 32 128, warp_tile = 16 16 128, instruction_tile = 16 16 16`, where the runtime drops from 12.4 ms to 7.28 ms! Thank you @zasdfgbnm for the help.
Values measured on NVIDIA A100 SXM4 80 GB