[CUDA] GQA CUDA Kernel Fusion and Performance Optimization#26920
Conversation
Pull request overview
This PR introduces a significant CUDA kernel optimization for GroupQueryAttention (GQA) by implementing a fused kernel that combines QKV unpacking, Rotary Position Embeddings (RoPE), and KV cache append operations into a single kernel launch. This reduces kernel overhead and memory bandwidth requirements for the first prompt phase.
Key Changes:
- Introduces `UnpackQKVWithRoPEAndAppendKV` fused kernel that consolidates 4-5 separate operations
- Adds `FlashAttentionDecoding` fast path for subsequent prompts/token generation with shared KV buffers
- Refactors sequence length handling to use `past_seq_lens`, `total_seq_lens`, and `padded_seq_lens` arrays
Reviewed changes
Copilot reviewed 14 out of 15 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| test_gqa.py | Adds test for padding scenarios, updates tolerances, adds share_buffer and position_ids test coverage |
| group_query_attention_impl.cu | Implements fused UnpackQKV+RoPE+KVAppend kernel, adds FlashAttentionDecoding fast path |
| attention_kv_cache.cu | Adds fused KV concat with RoPE support, refactors to use past_seq_lens/total_seq_lens |
| rotary_embedding_impl.cu | Adds position_ids_format=2 for implicit position computation from past_seq_lens |
| flash_api.cc | Adds packed QKV support in flash attention API |
| group_query_attention.cc | Refactors buffer allocation and sequence length handling |
| attention_data.h | Updates data structure for new sequence length arrays and position_ids |
Copilot reviewed 15 out of 16 changed files in this pull request and generated no new comments.
Should we consider fusing RMSNorm inside RoPE as well? Some models such as Qwen-3 apply RMSNorm to Q and K after MatMul and before RoPE.
It is feasible. We can support it in the future.
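For context on the QK-RMSNorm pattern discussed above (used by models such as Qwen-3), here is a minimal NumPy sketch of applying RMSNorm per head to Q after the projection MatMul and before RoPE. The shapes, names, and `eps` value are illustrative, not the ONNX Runtime implementation:

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the last (head_size) dimension, then apply a learned scale.
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * weight

# Toy shapes: (batch, seq, num_heads, head_size)
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 4, 2, 8)).astype(np.float32)
gamma = np.ones(8, dtype=np.float32)  # per-head-dim scale, all ones here

q_normed = rms_norm(q, gamma)
# With gamma = 1, each head vector now has approximately unit RMS.
rms = np.sqrt(np.mean(q_normed ** 2, axis=-1))
```

In a fused kernel this normalization would run in registers immediately before the rotary rotation, so no extra pass over Q/K would be needed.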
## Summary

This PR significantly improves GroupQueryAttention (GQA) performance on CUDA by fusing multiple kernel launches, improving memory access patterns, and cleaning up sequence length semantics.

## Key Changes

### 1. Fused Kernels for Reduced Launch Overhead

| New Kernel | Operations Fused | Kernels Saved |
|------------|------------------|---------------|
| `UnpackQKVWithRoPEAndAppendKV` | Unpack packed QKV + RoPE Q/K + KV cache append | 4-5 |
| `ConcatNewToPastKVFused` | K append + V append (separate buffer mode) | 1 |
| `ConcatKVInPlaceFused` | K append + V append (shared buffer mode) | 1 |

### 2. New `RotaryDispatcher` Template (`rotary_common.cuh`)

Reusable RoPE implementation for fused kernels supporting:
- `float`, `half`, `BFloat16` element types
- `float2`, `float4` vector types
- Interleaved and half-split rotation modes

### 3. Sequence Length Semantics Cleanup

**Before:** Confusing `seqlens_k` / `seqlens_k_buff` with overloaded meanings.

**After:** Clear separation:
- `past_seq_lens` - offset where new tokens are appended
- `total_seq_lens` - total valid tokens after append
- `padded_seq_lens` - padded length for first-prompt masking

### 4. FlashAttention Fast Decode Path

New optimized path for token generation (`sequence_length == 1`, shared buffer):
- Bypasses the `GetSequenceLengths` kernel
- Passes `past_seq_lens` directly to Flash Attention
- Controlled by the `ORT_DISABLE_FLASH_DECODE` env var

### 5. Integer Overflow Prevention

All KV cache index calculations use `int64_t` to handle large `batch * heads * seq * head_size` products.

### 6. BFloat16 Vectorization

Added a `float4` (8-element) vectorized path for BFloat16 in `ConcatTensorToTensor`.
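The interleaved and half-split rotation modes supported by `RotaryDispatcher` can be illustrated with a small NumPy reference. This is a sketch of the math only, not the CUDA template; the function names are made up here:

```python
import numpy as np

def rope_half_split(x, cos, sin):
    # Half-split (NeoX-style): pair element i with element i + head_size/2.
    h = x.shape[-1] // 2
    x1, x2 = x[..., :h], x[..., h:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

def rope_interleaved(x, cos, sin):
    # Interleaved (GPT-J-style): rotate adjacent (even, odd) element pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out

head_size = 8
pos = 3  # absolute position, e.g. past_seq_len + offset within the new chunk
inv_freq = 1.0 / (10000 ** (np.arange(0, head_size, 2) / head_size))
angles = pos * inv_freq
cos, sin = np.cos(angles), np.sin(angles)

rng = np.random.default_rng(0)
x = rng.standard_normal(head_size)
out_half = rope_half_split(x, cos, sin)
out_int = rope_interleaved(x, cos, sin)
# Both modes are pure rotations, so the vector norm is preserved.
```

In the fused kernel the same rotation is applied to Q and K right after unpacking, which is what lets the separate `RotaryEmbedding` launch be elided in the packed path.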
## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `ORT_DISABLE_FLASH_DECODE` | `false` | Disable the fast decode optimization |
| `ORT_DISABLE_FUSED_KV` | `false` | Use unfused K/V append kernels |

## Test Changes

### Improved Test Coverage Strategy

Restructured `gqa_cuda_prompt_test_cases()` and `gqa_cuda_past_test_cases()` to explicitly iterate over kernel code path parameters:

```python
# NEW: Primary iteration over kernel code paths
for h in h_sizes_to_test:
    for packed in packed_opts:
        for rotary, rotary_interleaved in rotary_opts:
            for share_buffer in share_buffer_opts:
                # Secondary params (batch, seq, heads) rotate via modulo
```

| Mode | Before | After |
|------|--------|-------|
| Pipeline | 16 tests, 4/12 combos | 42 tests, 8/12 combos |
| Comprehensive | 81 tests, 4/12 combos | 178 tests, 12/12 combos |

### New Test Parameters

- Added `seqs = [(1, 1)]` for edge-case testing
- Added `heads = [(3, 1)]` for non-standard GQA ratios
- Added `h_sizes = [40]` for non-power-of-2 head sizes (tests rotary skip logic)

### New Test Configurations

- `share_buffer` config option (tests both buffer modes)
- `has_position_ids` testing on CUDA
- Padding prompt parity test
- Fused vs. unfused kernel parity tests (`TestFusedKernelParity`)
- Decoding-from-empty-cache test case `(1, 1)`

## Files Changed

**Core:**
- `group_query_attention_impl.cu` - Main implementation refactoring
- `attention_kv_cache.cu` - Fused append kernels
- `flash_api.cc` - Packed QKV stride handling

**New:**
- `rotary_common.cuh` - Reusable RoPE dispatcher

**Tests:**
- `test_gqa.py` - Extended test coverage

## Performance

For decoding or a subsequent prompt, we still use the original flash attention kernel, so performance is almost the same as the baseline. Here we only show results for the first prompt.

Below are results of `benchmark_gqa.py` on an H200 GPU. Note that latency is measured from an ONNX model containing a single GQA node, so it includes extra overhead.
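The modulo rotation of secondary parameters in the test-case generation described above can be sketched as follows. The lists and field names here are illustrative placeholders, not the actual `test_gqa.py` code:

```python
import itertools

# Primary parameters select distinct kernel code paths.
h_sizes = [40, 128]
packed_opts = [False, True]
rotary_opts = [(False, False), (True, False), (True, True)]
share_buffer_opts = [False, True]

# Secondary parameters rotate via modulo so every run still sees
# varied shapes without multiplying the total case count.
batches = [1, 3]
seqs = [(1, 1), (128, 128), (2048, 2048)]
heads = [(32, 8), (3, 1)]

cases = []
for i, (h, packed, (rotary, interleaved), share) in enumerate(
    itertools.product(h_sizes, packed_opts, rotary_opts, share_buffer_opts)
):
    cases.append(dict(
        head_size=h, packed=packed, rotary=rotary,
        rotary_interleaved=interleaved, share_buffer=share,
        batch=batches[i % len(batches)],
        seq=seqs[i % len(seqs)],
        heads=heads[i % len(heads)],
    ))

# 2 head sizes * 2 packing modes * 3 rotary modes * 2 buffer modes
# = 24 primary combinations, each paired with one rotating shape.
```

The design trade-off is that every kernel code path is hit exactly once per mode, while shape diversity comes for free from the rotating secondary parameters.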
The kernel speedup can be larger (see profiling results below).

### prompt-sm90-Llama3-8B-b1-h32_8x128-float16

**Configuration**: `batch=1, prompt (past_seq=0), num_heads=32, kv_heads=8, head_size=128, dtype=float16, gpu=H200`

Dense means Q, K and V are separate inputs. Packed means Q, K and V are packed into one input.

| Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** |
| --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- |
| 1024 | 0.470 | 0.277 | **1.70x** | 0.468 | 0.320 | **1.46x** |
| 2048 | 1.001 | 0.517 | **1.94x** | 0.990 | 0.590 | **1.68x** |
| 4096 | 2.691 | 1.174 | **2.29x** | 1.504 | 1.242 | **1.21x** |
| 8192 | 7.780 | 2.292 | **3.39x** | 7.933 | 4.004 | **1.98x** |

### prompt-sm90-Llama3-8B-b1-h32_8x128-bfloat16

**Configuration**: `batch=1, prompt (past_seq=0), num_heads=32, kv_heads=8, head_size=128, dtype=bfloat16, gpu=H200`

| Sequence Length | Dense Base (ms) | Dense Treat (ms) | **Dense Speedup** | Packed Base (ms) | Packed Treat (ms) | **Packed Speedup** |
| --------------: | --------------: | ---------------: | :---------------- | ---------------: | ----------------: | :----------------- |
| 1024 | 0.477 | 0.274 | **1.74x** | 0.486 | 0.332 | **1.46x** |
| 2048 | 1.078 | 0.500 | **2.16x** | 1.087 | 0.601 | **1.81x** |
| 4096 | 2.633 | 1.144 | **2.30x** | 3.017 | 1.282 | **2.35x** |
| 8192 | 7.933 | 2.712 | **2.93x** | 7.933 | 4.003 | **1.98x** |

# Profiling Comparison (Prompt Phase)

**Summary**: Switching from `flash_fwd_splitkv_kernel` to the standard `flash_fwd_kernel` for the prompt phase (SeqLen=2048) results in a **~3x reduction in attention kernel latency** and a **~2x improvement in total operator latency**.

## 1. Packed QKV

**Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32, kv_heads=8, head_size=128`

| Metric | Baseline | Treatment | Delta |
| :--- | :--- | :--- | :--- |
| **Total Latency** | **639.3 us** | **287.0 us** | **2.23x Speedup** |
| **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.10 us | `flash_fwd_kernel`<br>187.70 us | **3.08x Speedup** |
| **Helper Kernels** | `ConcatNewToPastKV`: 4.71 us | `UnpackQKVWithRoPEAndAppendKV`: 32.44 us<br>`GetSequenceLengths`: 1.63 us | *Fused ops added* |

> **Note**: The Treatment introduces a fused `UnpackQKVWithRoPEAndAppendKV` kernel that performs the necessary pre-processing. Despite this added cost (~29 us), the gain from using the efficient `flash_fwd_kernel` instead of `flash_fwd_splitkv_kernel` yields a significant net speedup.

## 2. Dense (Separated QKV)

**Configuration**: `batch=1, seq_len=2048, past_seq=0, num_heads=32, kv_heads=8, head_size=128`

| Metric | Baseline | Treatment | Delta |
| :--- | :--- | :--- | :--- |
| **Total Latency** | **0.6468 ms** | **0.3226 ms** | **2.00x Speedup** |
| **Attention Kernel** | `flash_fwd_splitkv_kernel`<br>567.25 us | `flash_fwd_kernel`<br>184.29 us | **3.08x Speedup** |
| **Helper Kernels** | `ConcatNewToPastKV`: 4.68 us | `RotaryEmbeddingBSNH`: 48.94 us<br>`ConcatNewToPastKVFused`: 13.04 us<br>`GetSequenceLengths`: 1.52 us | *See below* |

> **Note**: As in the Packed case, the switch to the standard Flash Attention forward kernel drives the improvement. The pre-processing is handled by `RotaryEmbeddingBSNH` and `ConcatNewToPastKVFused` in the treatment.