
CTA Swizzles #87

Merged
mmigdal-nv merged 3 commits into tracking-matmul from MM/cta_swizzle
Mar 29, 2023

Conversation

@mmigdal-nv
Collaborator

@mmigdal-nv mmigdal-nv commented Mar 28, 2023

Adds a CTA swizzle to change the order in which the tiles of the output matrix are processed.
This swizzle increases data reuse from A and B when iterating over gridDim.x. In practice, CTAs turn out to be launched by iterating over gridDim.x first (the launch order is unspecified; it just happens to behave this way). As a result, the current wave contains CTAs that compute square sub-matrices of C, which increases the L2 hit rate.

The best factor seems to be 4; it will be part of the heuristics. Setting the factor to 1 disables this swizzle.

On an 8192x8192x8192 matmul with the default config, the speedup is about 20%.
On a 16384x16384x16384 matmul, the runtime drops from 51 ms to 38 ms, about 26%.

An extreme example is the following case: `MNK = 6144 6144 6144, layout=NT, stages=0, cta_tile = 32 32 128, warp_tile = 16 16 128, instruction_tile = 16 16 16`, where the runtime drops from 12.4 ms to 7.28 ms!
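The reordering above can be sketched as a remapping of the linear CTA index. The following is a minimal illustration, not the PR's actual code (the helper name and exact index math are assumptions): consecutive indices walk down a `factor`-row-tall strip of output tiles before moving to the next column, so a wave of CTAs covers a roughly square block of C instead of one long row.

```cpp
#include <cassert>
#include <utility>

// Hypothetical sketch of a grid (CTA) swizzle. Maps a linear CTA index to
// output-tile coordinates (tile_m, tile_n) on a tiles_m x tiles_n grid so
// that consecutive indices fill a `factor`-tall strip column by column.
// For simplicity this assumes tiles_m is divisible by `factor`.
std::pair<int, int> swizzledTile(int linear_id, int tiles_n, int factor) {
  int strip = linear_id / (factor * tiles_n);    // which group of `factor` tile rows
  int within = linear_id % (factor * tiles_n);   // position inside that strip
  int tile_m = strip * factor + within % factor; // row within the strip advances fastest
  int tile_n = within / factor;                  // then move to the next column
  return {tile_m, tile_n};
}
```

With `factor == 1` the mapping reduces to the identity (row-major order), matching the "setting the factor to 1 disables this swizzle" behavior described above.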

Thank you @zasdfgbnm for the help.
Values measured on NVIDIA A100 SXM4 80 GB

@mmigdal-nv mmigdal-nv requested a review from zasdfgbnm March 28, 2023 23:13
Collaborator

@zasdfgbnm zasdfgbnm left a comment


Good work! LGTM except some minor comments.
Please cherry-pick this to main without the change in csrc/predicate_compute.cpp.
This will add a trivial predicate, but performance should still be better than without the swizzle.

//! A2 A3 B3 B4 C1 C2 C3 C4 D1 D2 D3 D4
//! C1 C2 D1 D2
//! C3 C4 D3 D4
int swizzle_factor = 1;
Collaborator


Should we call this grid_swizzle_factor to distinguish from the prologue "swizzle" that is used to remove bank conflict?

auto inputs = fp16MatmulAtInput(M, N, K, layout);

FusionExecutor fe;
fe.setMeasureKernelTimeFlag(true);
Collaborator


I think this will add some debug printing when running this test? Could you comment this line out?

Collaborator Author


No debug prints. The only effect is to create cudaEvents and return the runtime through fe.kernelTimeMs() (otherwise it's just zero).

Collaborator


Ahhh, I thought it would automatically std::cout << fe.kernelTimeMs(). If no debug prints, then we can keep it.


// Gmem pipeline stage

for (auto layout : {MatmulLayout::TT}) {
Collaborator


Should this be kAllSupportedMatmulLayout?

Collaborator Author

@mmigdal-nv mmigdal-nv Mar 28, 2023


I just removed it to keep the runtime short, as the test already checks four configs per layout.

Collaborator Author


The current test takes 15 s; with all layouts it would jump to 45 s.

Collaborator


OK, then let's just test one layout.

mmigdal-nv and others added 2 commits March 29, 2023 01:35
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
@mmigdal-nv mmigdal-nv merged commit 01ca45f into tracking-matmul Mar 29, 2023
@mmigdal-nv mmigdal-nv deleted the MM/cta_swizzle branch March 29, 2023 08:31
mmigdal-nv added a commit that referenced this pull request Mar 29, 2023
@mmigdal-nv mmigdal-nv mentioned this pull request Mar 29, 2023
mmigdal-nv added a commit that referenced this pull request Mar 29, 2023
Cherry-pick of the changes made in
#87 into main.

Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
zasdfgbnm added a commit that referenced this pull request Mar 29, 2023
I am seeing a trivial thread predicate in one matmul schedule
(#87), and that trivial thread
predicate causes a 10% slowdown. So I started work on removing that
predicate. My first step is a rewrite of `ParallelDimensionMap`. The new
code is much shorter in lines of code and should produce better results
(for example, in `FusionParallelDimensionMap3_CUDA`, `blockDim.x`
becomes `20`).

This PR uses the following equation to calculate the extent of a
parallel dim:
```C++
parallel_dim = simplifyExpr(max(extent1, extent2, ..., extentN));
```
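As an illustration of the equation above: if three IterDomains in the fusion are parallelized on the same parallel type (say `TIDx`) with constant extents 16, 20, and 8, the launched `blockDim.x` must be the maximum, 20. A minimal sketch of that rule, using a hypothetical helper (not Fuser's actual API); with constant extents the `simplifyExpr(max(...))` expression reduces to a plain integer max:

```cpp
#include <algorithm>
#include <vector>

// Illustrative sketch only: the launch extent of a parallel dimension is
// the maximum of the extents of all IterDomains mapped to it. A parallel
// type with no mapped domains has the trivial extent 1.
int parallelDimExtent(const std::vector<int>& extents) {
  int result = 1;
  for (int e : extents) {
    result = std::max(result, e);
  }
  return result;
}
```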

Future work:
- [ ] Simplify the trivial predicate in
mmigdal-nv#1; what we need to do is:
- When simplifying all expressions, we should assume `parallel_index <
parallel_dim` in addition to
https://github.com/NVIDIA/Fuser/blob/3e69f69024ec98c08b72889c1f32871963338bb7/csrc/expr_simplifier.cpp#L152
- Omit the predicate not only when `isExact(ptype)`, but also when the
loop extent equals `parallel_dim`.
- [ ] Simplify the following to use the extent computed in this PR:

https://github.com/NVIDIA/Fuser/blob/548d5d2698f55205cbe84bec8ca2aff2051fb88b/csrc/executor.cpp#L660
