diff --git a/aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py b/aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py index bc58d2421f..b8fd6949ba 100644 --- a/aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py +++ b/aiter/ops/triton/_triton_kernels/fp8_mqa_logits.py @@ -29,7 +29,9 @@ def _fp8_mqa_logits_kernel( BLOCK_KV: tl.constexpr, ): row_id = tl.program_id(0) - + # go from larger to smaller in terms of work + # to reduce the tail effect + row_id = tl.num_programs(0) - row_id - 1 tl.assume(row_id >= 0) tl.assume(stride_q_s > 0) tl.assume(stride_q_h > 0) @@ -39,6 +41,8 @@ def _fp8_mqa_logits_kernel( tl.assume(stride_w_s > 0) tl.assume(stride_w_h > 0) + logits_row_ptrs = logits_ptr + row_id * stride_logits_s + h_inds = tl.arange(0, NUM_HEADS)[:, None] d_inds = tl.arange(0, HEAD_SIZE) @@ -57,9 +61,9 @@ def _fp8_mqa_logits_kernel( start_ind = tl.maximum(start_ind, 0) end_ind = tl.minimum(end_ind, seq_len_kv) - unmasked_end_ind = (end_ind // BLOCK_KV) * BLOCK_KV + shifted_end = end_ind - start_ind + shifted_unmasked_end = shifted_end // BLOCK_KV * BLOCK_KV - logits_row_ptrs = logits_ptr + row_id * stride_logits_s kv_col_offsets = tl.arange(0, BLOCK_KV) + start_ind kv_ptrs = ( KV_ptr + kv_col_offsets[None, :] * stride_kv_s + d_inds[:, None] * stride_kv_d @@ -70,12 +74,12 @@ def _fp8_mqa_logits_kernel( logits_ptrs = logits_row_ptrs + kv_col_offsets * stride_logits_k # Loop over KV tiles - for _ in tl.range(start_ind, unmasked_end_ind, BLOCK_KV): + for _ in tl.range(0, shifted_unmasked_end, BLOCK_KV): kv_block = tl.load(kv_ptrs) kv_scales = tl.load(kv_scales_ptrs) # [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV] - scores = tl.dot(q_block, kv_block) + scores = tl.dot(q_block, kv_block, input_precision="ieee") # Multiply by kv_scales (broadcast along rows) scores = scores * kv_scales[None, :] # ReLU @@ -88,23 +92,22 @@ def _fp8_mqa_logits_kernel( kv_ptrs += BLOCK_KV * stride_kv_s kv_scales_ptrs += BLOCK_KV logits_ptrs += BLOCK_KV * stride_logits_k - - if unmasked_end_ind != end_ind: - # masked load - kv_col_offsets = tl.arange(0, BLOCK_KV) + unmasked_end_ind - kv_col_mask = kv_col_offsets < seq_len_kv - kv_block = tl.load(kv_ptrs, mask=kv_col_mask[None, :], other=0.0) - kv_scales = tl.load(kv_scales_ptrs, mask=kv_col_mask, other=0.0) - - # [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV] - scores = tl.dot(q_block, kv_block) - # Multiply by kv_scales (broadcast along rows) - scores = scores * kv_scales[None, :] - # ReLU - scores = tl.maximum(scores, 0.0) - scores = scores * w_block - # [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ] - scores = tl.sum(scores, axis=0) - # masked store - in_window = (kv_col_offsets >= start_ind) & (kv_col_offsets < end_ind) - tl.store(logits_ptrs, scores, mask=in_window) + kv_col_offsets += BLOCK_KV + + # masked load + kv_col_mask = kv_col_offsets < end_ind + kv_block = tl.load(kv_ptrs, mask=kv_col_mask[None, :], other=0.0) + kv_scales = tl.load(kv_scales_ptrs, mask=kv_col_mask, other=0.0) + + # [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV] + scores = tl.dot(q_block, kv_block, input_precision="ieee") + # Multiply by kv_scales (broadcast along rows) + scores = scores * kv_scales[None, :] + # ReLU + scores = tl.maximum(scores, 0.0) + scores = scores * w_block + # [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ] + scores = tl.sum(scores, axis=0) + # masked store + in_window = (kv_col_offsets >= start_ind) & (kv_col_offsets < end_ind) + tl.store(logits_ptrs, scores, mask=in_window) diff --git a/aiter/ops/triton/fp8_mqa_logits.py b/aiter/ops/triton/fp8_mqa_logits.py index 512973b853..ad27752d1c 100644 --- a/aiter/ops/triton/fp8_mqa_logits.py +++ b/aiter/ops/triton/fp8_mqa_logits.py @@ -44,6 +44,12 @@ def fp8_mqa_logits( stride_kv_s, stride_kv_d = KV.stride() stride_w_s, stride_w_h = weights.stride() stride_logits_s, stride_logits_k = logits.stride() + + # heuristic for MFMA instruction shape + matrix_instr_nonkdim = 32 + if seq_len <= 1024: + matrix_instr_nonkdim = 16 + _fp8_mqa_logits_kernel[(seq_len,)]( Q_ptr=Q, KV_ptr=KV, @@ -69,7 +75,7 @@ def fp8_mqa_logits( num_warps=4, num_stages=2, waves_per_eu=2, - matrix_instr_nonkdim=16, + matrix_instr_nonkdim=matrix_instr_nonkdim, ) return logits diff --git a/op_tests/op_benchmarks/triton/bench_fp8_mqa_logits.py b/op_tests/op_benchmarks/triton/bench_fp8_mqa_logits.py new file mode 100644 index 0000000000..ae73c81f1a --- /dev/null +++ b/op_tests/op_benchmarks/triton/bench_fp8_mqa_logits.py @@ -0,0 +1,117 @@ +import torch +import triton +import argparse +from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits +from aiter.ops.triton.utils.types import e4m3_dtype +from op_tests.triton_tests.test_fp8_mqa_logits import ( + per_custom_dims_cast_to_fp8, + generate_cp_test_data, +) +from op_tests.op_benchmarks.triton.utils.benchmark_utils import ( + print_vgpr, + get_caller_name_no_ext, +) + + +def calculate_tflops(start_inds, end_inds, num_heads_q, head_dim, time_ms): + time_s = time_ms * 1e-3 + start_inds = start_inds.to("cpu").numpy() + end_inds = end_inds.to("cpu").numpy() + total_flops = 0.0 + for i in range(len(start_inds)): + start = start_inds[i] + end = end_inds[i] + total_flops += 2.0 * num_heads_q * head_dim * (end - start) + # TFLOPs = total FLOPs / (time in seconds * 1e12) + tflops = total_flops / (time_s * 1e12) + + return tflops + + +def run_benchmark(args): + x_names = ["seq_q_l", "seq_kv_l", "num_heads_q", "head_dim"] + x_vals_list = [[args.seq_q_l, args.seq_kv_l, args.num_heads_q, args.head_dim]] + if args.metric == "time": + ylabel = "Time (ms)" + elif args.metric == "throughput": + ylabel = "TFLOPs" + else: + raise NotImplementedError(f"{args.metric} is not supported") + + line_names = [ylabel] + line_vals = [ylabel] + benchmark = triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg="unit", + line_vals=line_vals, + line_names=line_names, + styles=[("green", "-")], + ylabel=ylabel, + plot_name=get_caller_name_no_ext(), + args={"metric": args.metric}, + ) + + @triton.testing.perf_report([benchmark]) + def bench_fp8_mqa_logits( + seq_q_l, seq_kv_l, num_heads_q, head_dim, metric, **kwargs + ): + q = torch.randn( + seq_q_l, num_heads_q, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv = torch.randn(seq_kv_l, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_q_l, num_heads_q, device="cuda", dtype=torch.float32) + + ks = torch.zeros(seq_q_l, dtype=torch.int, device="cuda") + ke = torch.arange(seq_q_l, dtype=torch.int, device="cuda") + ( + seq_kv_l - seq_q_l + ) + + q_fp8 = q.to(e4m3_dtype) + kv_fp8, scales = per_custom_dims_cast_to_fp8(kv, (0,), False) + + func = lambda: fp8_mqa_logits(q_fp8, kv_fp8, scales, weights, ks, ke) + + time_ms = triton.testing.do_bench(func, warmup=25, rep=100) + tflops = calculate_tflops(ks, ke, num_heads_q, head_dim, time_ms) + + # Return exactly one scalar depending on which metric is active + if metric == "time": + return time_ms + elif metric == "throughput": + return tflops + else: + raise ValueError("Unknown metric: " + metric) + + bench_fp8_mqa_logits.run(save_path="." if args.o else None, print_data=True) + + +def main(): + parser = argparse.ArgumentParser( + description="FP8 MQA Logits Benchmark", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--num_heads_q", type=int, default=64, help="num. q heads") + parser.add_argument("--head_dim", type=int, default=128, help="head dim size") + parser.add_argument( + "--seq_q_l", type=int, default=4096, help="Input sequence length" + ) + parser.add_argument( + "--seq_kv_l", type=int, default=4096, help="Output sequence length" + ) + parser.add_argument( + "-o", action="store_true", help="Write performance results to CSV file" + ) + parser.add_argument( + "--metric", + type=str, + choices=["time", "throughput"], + default="throughput", + help="metric to plot", + ) + args = parser.parse_args() + run_benchmark(args) + + +if __name__ == "__main__": + main()