[TRITON]: Fp4gemm m=256 tuning #533
Conversation
@Chi-Chu319 LGTM!
Also, I was able to reproduce the performance numbers and everything looks good!
rahulbatra85
left a comment
Looks good except for the conflict. Also, can you please run the config changes by Cagri Mehmut and Ali?
Thanks!
I compared the performance against main and, aside from the following shapes, there is no regression, only performance improvement or no change. "x" means all M dimensions I tried. The regressions are likely due to the freshly changed configs that are still missing.
I am tuning them. Most of them are group-size related; with remapping you want a smaller group size.
Looks good to me. No regression for any of the shapes I tested (LLAMA- and DS-related ones), only improvements or no change.
juuso-oskari
left a comment
LGTM
Authors: @Chi-Chu319 @juuso-oskari
This PR primarily adds tuning for M=256, N=16384 and XCD remapping, along with some additional tunings.
The full performance results are available at
https://amdcloud-my.sharepoint.com/:x:/r/personal/alizaidy_amd_com/_layouts/15/doc2.aspx?sourcedoc=%7B2117442c-b906-49c5-8ace-0c07b925dc14%7D&action=edit&activeCell=%27fp4gemm_m256-tuning
We move away from split-K because we can satisfy the parallelism with GRID_MN while still respecting the constraint num_warps * 32 <= BLOCK_N (required for the preshuffled scales to work).
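To make the constraint concrete, here is a minimal, hypothetical sketch of filtering candidate tuning configs so that only those satisfying num_warps * 32 <= BLOCK_N are kept; the option lists and config format below are illustrative, not the actual tuning space used in this PR.

```python
import itertools

# Hypothetical option lists for illustration only; the real tuning space in
# this PR may differ.
BLOCK_N_OPTIONS = (32, 64, 128, 256)
NUM_WARPS_OPTIONS = (1, 2, 4, 8)

def candidate_configs():
    """Yield only (BLOCK_N, num_warps) combinations that keep the
    preshuffled-scales constraint num_warps * 32 <= BLOCK_N satisfied."""
    for block_n, num_warps in itertools.product(BLOCK_N_OPTIONS, NUM_WARPS_OPTIONS):
        if num_warps * 32 <= block_n:
            yield {"BLOCK_N": block_n, "num_warps": num_warps}
```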
We also made a chunked version of xcd_remap, which now brings a perf boost (as opposed to previously degrading perf). The benefit of the chunked version:
With the previous remapping, we divided all the pids into 8 chunks and sent one chunk to each of the 8 XCDs. This effectively made a single XCD process its own contiguous chunk of the B matrix of size (K x N // 8). This is good for L2 usage, but the L2 is most likely already saturated by caching the A matrix of size (M x K). At the same time it is disastrous for L3 caching, because the concurrent memory reads coming from different XCDs are now separated by up to N // 8 elements.
We solved this by having xcd_remap map into multiple contiguous chunks of size CHUNK_SIZE (a tunable variable) instead of a single contiguous chunk of size num_pid_n // 8.
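As a rough illustration of the idea (not the exact kernel code in this PR), a chunked remap can be written as a pure-Python sketch; the function name, the remainder handling, and num_xcds=8 are assumptions here.

```python
def remap_xcd_chunked(pid: int, grid_size: int, chunk_size: int, num_xcds: int = 8) -> int:
    """Hypothetical sketch of a chunked XCD remap (not the PR's exact code).

    The hardware dispatches consecutive pids round-robin across the XCDs
    (XCD = pid % num_xcds). Instead of giving each XCD one large contiguous
    range of pids (grid_size // num_xcds), this remap gives each XCD many
    small contiguous chunks of `chunk_size` pids, which keeps the concurrent
    reads from different XCDs closer together in memory.
    """
    super_chunk = num_xcds * chunk_size        # pids covered by one round of chunks
    base = (pid // super_chunk) * super_chunk  # start of the current super-chunk
    offset = pid - base
    xcd = offset % num_xcds                    # XCD this pid lands on by default
    local = offset // num_xcds                 # index within that XCD's chunk
    # Remainder handling for the last, partial super-chunk is omitted for brevity.
    return base + xcd * chunk_size + local
```

Note that with chunk_size = grid_size // num_xcds this degenerates into the previous single-chunk-per-XCD behavior, which is why making CHUNK_SIZE tunable matters.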
Performance
The performance for M=256, N=16384:
_gemm_afp4_wfp4_kernel: main (commit e31c2e0) vs. fp4gemm_m256-tuning
gemm_afp4wfp4_preshuffled_scales: main (commit e31c2e0) vs. fp4gemm_m256-tuning
Performance for varying M (showing that the pid remapping does not hurt perf and even improves it in some cases):