-
Notifications
You must be signed in to change notification settings - Fork 171
[TRITON] FP8 MQA optimizations #1422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
4c4f729
add sparse attn. kernels
cagrikymk b10497f
formatting
cagrikymk afdfd81
add comments
cagrikymk bed26a4
formatting
cagrikymk de49d68
comment changes
cagrikymk 02079a7
fp8 mqa kernel optimizations
cagrikymk 069118d
formatting
cagrikymk 205ba1c
fp8 mqa updates
cagrikymk 53c8017
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk 6a89d5d
renaming
cagrikymk 57ea55e
formatting
cagrikymk d51e75a
fix comments
cagrikymk da77e58
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk 3216d3d
nonkdim change for better perf
cagrikymk 0f0253c
add mqa bench
cagrikymk d74b829
help msgs
cagrikymk ce23e9e
update bench code
cagrikymk ba72f16
formatting
cagrikymk 4f637e6
more optimizations
cagrikymk 67d470a
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk e632edd
cleanup
cagrikymk 70b58f7
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk 19ed0d1
loop bound changes
cagrikymk bfe4cb1
kernel name fix
cagrikymk 33f0541
formatting changes
cagrikymk 9792c1a
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk 5233364
mqa bench fix and formatting
cagrikymk 3952c05
Merge branch 'main' into cagri/sparse_mla_attn_upstream
cagrikymk File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| import torch | ||
| import triton | ||
| import argparse | ||
| from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits | ||
| from aiter.ops.triton.utils.types import e4m3_dtype | ||
| from op_tests.triton_tests.test_fp8_mqa_logits import ( | ||
| per_custom_dims_cast_to_fp8, | ||
| generate_cp_test_data, | ||
| ) | ||
| from op_tests.op_benchmarks.triton.utils.benchmark_utils import ( | ||
| print_vgpr, | ||
| get_caller_name_no_ext, | ||
| ) | ||
|
|
||
|
|
||
| def calculate_tflops(start_inds, end_inds, num_heads_q, head_dim, time_ms): | ||
| time_s = time_ms * 1e-3 | ||
| start_inds = start_inds.to("cpu").numpy() | ||
| end_inds = end_inds.to("cpu").numpy() | ||
| total_flops = 0.0 | ||
| for i in range(len(start_inds)): | ||
| start = start_inds[i] | ||
| end = end_inds[i] | ||
| total_flops += 2.0 * num_heads_q * head_dim * (end - start) | ||
| # TFLOPs = total FLOPs / (time in seconds * 1e12) | ||
| tflops = total_flops / (time_s * 1e12) | ||
|
|
||
| return tflops | ||
|
|
||
|
|
||
| def run_benchmark(args): | ||
| x_names = ["seq_q_l", "seq_kv_l", "num_heads_q", "head_dim"] | ||
| x_vals_list = [[args.seq_q_l, args.seq_kv_l, args.num_heads_q, args.head_dim]] | ||
| if args.metric == "time": | ||
| ylabel = "Time (ms)" | ||
| elif args.metric == "throughput": | ||
| ylabel = "TFLOPs" | ||
| else: | ||
| raise NotImplementedError(f"{args.metric} is not supported") | ||
|
|
||
| line_names = [ylabel] | ||
| line_vals = [ylabel] | ||
| benchmark = triton.testing.Benchmark( | ||
| x_names=x_names, | ||
| x_vals=x_vals_list, | ||
| line_arg="unit", | ||
| line_vals=line_vals, | ||
| line_names=line_names, | ||
| styles=[("green", "-")], | ||
| ylabel=ylabel, | ||
| plot_name=get_caller_name_no_ext(), | ||
| args={"metric": args.metric}, | ||
| ) | ||
|
|
||
| @triton.testing.perf_report([benchmark]) | ||
| def bench_fp8_mqa_logits( | ||
| seq_q_l, seq_kv_l, num_heads_q, head_dim, metric, **kwargs | ||
| ): | ||
| q = torch.randn( | ||
| seq_q_l, num_heads_q, head_dim, device="cuda", dtype=torch.bfloat16 | ||
| ) | ||
| kv = torch.randn(seq_kv_l, head_dim, device="cuda", dtype=torch.bfloat16) | ||
| weights = torch.randn(seq_q_l, num_heads_q, device="cuda", dtype=torch.float32) | ||
|
|
||
| ks = torch.zeros(seq_q_l, dtype=torch.int, device="cuda") | ||
| ke = torch.arange(seq_q_l, dtype=torch.int, device="cuda") + ( | ||
| seq_kv_l - seq_q_l | ||
| ) | ||
|
|
||
| q_fp8 = q.to(e4m3_dtype) | ||
| kv_fp8, scales = per_custom_dims_cast_to_fp8(kv, (0,), False) | ||
|
|
||
| func = lambda: fp8_mqa_logits(q_fp8, kv_fp8, scales, weights, ks, ke) | ||
|
|
||
| time_ms = triton.testing.do_bench(func, warmup=25, rep=100) | ||
| tflops = calculate_tflops(ks, ke, num_heads_q, head_dim, time_ms) | ||
|
|
||
| # Return exactly one scalar depending on which metric is active | ||
| if metric == "time": | ||
| return time_ms | ||
| elif metric == "throughput": | ||
| return tflops | ||
| else: | ||
| raise ValueError("Unknown metric: " + metric) | ||
|
|
||
| bench_fp8_mqa_logits.run(save_path="." if args.o else None, print_data=True) | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser( | ||
| description="FP8 MQA Logits Benchmark", | ||
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||
| ) | ||
| parser.add_argument("--num_heads_q", type=int, default=64, help="num. q heads") | ||
| parser.add_argument("--head_dim", type=int, default=128, help="head dim size") | ||
| parser.add_argument( | ||
| "--seq_q_l", type=int, default=4096, help="Input sequence length" | ||
| ) | ||
| parser.add_argument( | ||
| "--seq_kv_l", type=int, default=4096, help="Output sequence length" | ||
| ) | ||
| parser.add_argument( | ||
| "-o", action="store_true", help="Write performance results to CSV file" | ||
| ) | ||
| parser.add_argument( | ||
| "--metric", | ||
| type=str, | ||
| choices=["time", "throughput"], | ||
| default="throughput", | ||
| help="metric to plot", | ||
| ) | ||
| args = parser.parse_args() | ||
| run_benchmark(args) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.