[CUDA] Update Flash Attention Implementation and APIs#26937
Merged
Conversation
Contributor
Pull request overview
This PR updates the Flash Attention implementation in ONNX Runtime by syncing with the latest kernels from the upstream flash-attention repository and extending the internal API to support advanced caching scenarios with non-contiguous batch indices and left-padding.
Key Changes:
- Extended the `mha_fwd` and `mha_fwd_kvcache` APIs with two new optional parameters: `cache_batch_idx` (for non-contiguous batch indexing) and `leftpad_k` (for left-padding support)
- Introduced a namespace configuration system via `namespace_config.h` for better isolation and flexibility
- Added an explicit causal template parameter to the kernel dispatch functions, creating separate instantiations for causal and non-causal attention patterns
- Updated numerous kernel files to align with the new template signatures and namespace conventions
- Standardized copyright headers to 2024
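The optional-parameter extension can be sketched as follows. The parameter names follow the PR, but the signature, struct, and body here are illustrative assumptions, not the real `flash_api.h` declaration: defaulting both new pointers to `nullptr` keeps existing call sites source-compatible.

```cpp
#include <cassert>

// Minimal stand-in result type; the real function configures kernel params.
struct Stub {
  bool used_cache_batch_idx = false;
  bool used_leftpad = false;
};

// Hypothetical sketch of the extended kvcache entry point: the two new
// parameters are optional pointers, so callers that predate the change
// keep compiling without modification.
Stub mha_fwd_kvcache_sketch(const float* q, const float* kcache, const float* vcache,
                            const int* cache_batch_idx = nullptr,  // non-contiguous batch indices
                            const int* leftpad_k = nullptr) {      // per-batch left padding of K
  Stub s;
  s.used_cache_batch_idx = (cache_batch_idx != nullptr);
  s.used_leftpad = (leftpad_k != nullptr);
  return s;
}
```

An old-style call such as `mha_fwd_kvcache_sketch(q, kcache, vcache)` behaves exactly as before, while new callers opt in by passing the extra pointers.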
Reviewed changes
Copilot reviewed 64 out of 64 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| `flash_api.h` / `flash_api.cc` | Extended API signatures to accept `cache_batch_idx` and `leftpad_k` parameters; updated force-split kernel logic |
| `namespace_config.h` | New file defining the `FLASH_NAMESPACE` macro for namespace management |
| `flash.h` | Updated function templates to include the `Is_causal` template parameter; added new params struct fields |
| `block_info.h` | Added a `leftpad_k` field and integrated it into offset calculations |
| `flash_fwd_kernel.h` | Major updates including LSE layout handling, leftpad support in rotary embeddings, and return-softmax support |
| `flash_fwd_launch_template.h` | Updated kernel dispatch logic with the causal template parameter; modified shared-memory size handling |
| `kernel_traits.h` | Changed copy atoms from `DefaultCopy` to `AutoVectorizingCopyWithAssumedAlignment<128>` for better performance |
| `utils.h`, `softmax.h`, `mask.h`, `rotary.h` | Updated namespace declarations from `onnxruntime::flash` to `FLASH_NAMESPACE` |
| `flash_fwd_hdim*_*.cu` | All kernel instantiation files updated with the new template signatures (added `Is_causal` parameter) |
| `flash_fwd_split_hdim*_*.cu` | All split-kernel files updated with the new template signatures |
| `update_kernels.py` | Python script to generate the kernel files programmatically |
| `group_query_attention_impl.cu` | Updated the caller to pass two additional `nullptr` arguments to `mha_fwd_kvcache` |
| `onnxruntime_providers_cpu.cmake` | Updated build-filter comments for quick build mode |
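The `Is_causal` template-parameter change can be illustrated with the usual "lift a runtime flag to compile time" dispatch pattern. The names below are illustrative, not the exact upstream macros: the point is that causal and non-causal attention become separate template instantiations, each compiled without the other's branches.

```cpp
#include <cassert>

// Stand-in for a kernel launcher specialized on the causal flag. In the real
// code the template parameter selects masking logic at compile time; here it
// just reports which instantiation ran.
template <bool Is_causal>
int run_mha_fwd_sketch() {
  return Is_causal ? 1 : 0;  // 1 = causal instantiation, 0 = non-causal
}

// Runtime-to-compile-time dispatch: each branch instantiates a distinct kernel.
int dispatch(bool is_causal) {
  if (is_causal) {
    return run_mha_fwd_sketch<true>();
  }
  return run_mha_fwd_sketch<false>();
}
```

This is why the PR touches every `flash_fwd_hdim*_*.cu` instantiation file: each head-dimension/dtype combination now needs both a causal and a non-causal instantiation.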
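The `namespace_config.h` idea can be sketched as a macro-based namespace wrapper. The default shown below is an assumption for illustration (the real header's default is not reproduced here); the mechanism is that every vendored kernel file opens `namespace FLASH_NAMESPACE` instead of a hard-coded name, so a build can redefine the macro to avoid symbol collisions with other copies of the upstream library.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical sketch of namespace_config.h: a configurable namespace macro
// with a fallback default. The name flash_sketch is a stand-in, not the
// actual default used by ONNX Runtime.
#ifndef FLASH_NAMESPACE
#define FLASH_NAMESPACE flash_sketch
#endif

namespace FLASH_NAMESPACE {
// Any kernel or helper declared here lands in the configured namespace.
inline const char* ns_tag() { return "flash"; }
}  // namespace FLASH_NAMESPACE
```

Call sites then refer to `FLASH_NAMESPACE::ns_tag()` (or whatever symbols the headers declare), and a `-DFLASH_NAMESPACE=...` compile flag relocates everything at once.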
kunal-vaishnavi
previously approved these changes
Jan 9, 2026
kunal-vaishnavi
approved these changes
Jan 10, 2026
alex-spacemit
pushed a commit
to spacemit-com/onnxruntime
that referenced
this pull request
Jan 20, 2026
## Summary

This PR updates the Flash Attention implementation in ONNX Runtime, syncing with newer kernel sources in https://github.com/Dao-AILab/flash-attention, and extending the internal API to support additional features required for advanced caching scenarios. It also aligns specific kernels with the official implementation.

## Changes

- **Flash Attention Kernels**: Updated/added Flash Attention forward kernels and headers in `onnxruntime/contrib_ops/cuda/bert/flash_attention/`.
- **API Extension**: Updated `mha_fwd` and `mha_fwd_kvcache` in `flash_api.h` and `flash_api.cc` to accept two new optional parameters:
  - `cache_batch_idx`: Indices into the KV cache (support for non-contiguous batch indices).
  - `leftpad_k`: Support for left-padding in the key sequence.
- **Alignment & Fixes**:
  - **Cleanup**: Removed the redundant `kInfinity` definition in `flash_fwd_kernel.h`.
  - **Includes**: Added the missing `<core/providers/cuda/shared_inc/cuda_call.h>` include in `flash_fwd_launch_template.h`.
  - **Integration**: Updated `group_query_attention_impl.cu` to align with the new `mha_fwd_kvcache` signature.
  - **Build Configuration**: Adjusted `onnxruntime_providers_cpu.cmake` to update the exclusion list for Flash Attention kernels in quick build mode.

## Implementation Details

- The `run_mha_fwd` helper now checks whether `cache_batch_idx` is provided alongside `k_new` to determine if the split kernel should be forced.
- The new parameters are propagated through the call stack to the underlying Flash Attention kernels.
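The force-split decision described in the implementation details reduces to a simple predicate. This is a simplified stand-in for the real `run_mha_fwd` dispatch, kept only to show the condition: when both a new K tensor and non-contiguous cache indices are present, the split-KV kernel path is forced.

```cpp
#include <cassert>

// Sketch of the split-kernel decision: force the split-KV path when
// cache_batch_idx is provided alongside k_new. Names mirror the PR
// description; the real helper also weighs other dispatch criteria.
bool force_split_kernel_sketch(const void* k_new, const int* cache_batch_idx) {
  return k_new != nullptr && cache_batch_idx != nullptr;
}
```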
alex-spacemit
pushed a commit
to spacemit-com/onnxruntime
that referenced
this pull request
Jan 27, 2026
Summary
This PR updates the Flash Attention implementation in ONNX Runtime, syncing with newer kernel sources in https://github.com/Dao-AILab/flash-attention, and extending the internal API to support additional features required for advanced caching scenarios. It also aligns specific kernels with the official implementation.
Changes
- **Flash Attention Kernels**: Updated/added Flash Attention forward kernels and headers in `onnxruntime/contrib_ops/cuda/bert/flash_attention/`.
- **API Extension**: Updated `mha_fwd` and `mha_fwd_kvcache` in `flash_api.h` and `flash_api.cc` to accept two new optional parameters:
  - `cache_batch_idx`: Indices into the KV cache (support for non-contiguous batch indices).
  - `leftpad_k`: Support for left-padding in the key sequence.
- **Alignment & Fixes**:
  - **Cleanup**: Removed the redundant `kInfinity` definition in `flash_fwd_kernel.h`.
  - **Includes**: Added the missing `<core/providers/cuda/shared_inc/cuda_call.h>` include in `flash_fwd_launch_template.h`.
  - **Integration**: Updated `group_query_attention_impl.cu` to align with the new `mha_fwd_kvcache` signature.
  - **Build Configuration**: Adjusted `onnxruntime_providers_cpu.cmake` to update the exclusion list for Flash Attention kernels in quick build mode.

Implementation Details
- The `run_mha_fwd` helper now checks whether `cache_batch_idx` is provided alongside `k_new` to determine if the split kernel should be forced.
- The new parameters are propagated through the call stack to the underlying Flash Attention kernels.
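The `leftpad_k` integration into `block_info.h` offset math can be sketched as follows. The struct and field names are illustrative (only `leftpad_k` comes from the PR): the idea is that the first valid K row for a batch is shifted past the left-padded positions before strides are applied.

```cpp
#include <cassert>

// Illustrative BlockInfo-style helper: the K row offset for a batch skips
// leftpad_k padded positions at the start of the cached key sequence.
struct BlockInfoSketch {
  int sum_s_k;    // flattened start row of this batch's K cache (assumed layout)
  int leftpad_k;  // left padding to skip in the key sequence

  // Element offset of the first valid K row, given the row stride.
  int k_offset(int row_stride) const {
    return (sum_s_k + leftpad_k) * row_stride;
  }
};
```

With `leftpad_k == 0` this degenerates to the original offset, which is consistent with the parameter being optional at the API level.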