
[CUDA] Speed up flash attention build #26924

Merged
tianleiwu merged 1 commit into main from tlwu/flash_att_quick_build on Jan 7, 2026
Conversation

@tianleiwu
Contributor

Summary

This pull request significantly reduces the build time for Flash Attention by removing support for less common head dimensions (160 and 224).

It also adds a quick-build option, `--cmake_extra_defines onnxruntime_QUICK_BUILD=ON`, which builds only the float16 flash attention kernel with head dimension 128. This can speed up development.

Key Changes

1. Flash Attention Build Optimization

  • Removed Head Dimensions: Deleted the source files and kernel instantiations for head dimensions 160 and 224 (both FP16 and BF16). These dimensions are rarely used, and removing them reduces the number of kernels to compile, speeding up the build.
  • Updated Dispatch Logic: Modified `static_switch.h` and `flash_api.h` to remove the dispatch cases for `kHeadDim = 160` and `kHeadDim = 224`.
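The effect of the trimmed dispatch can be sketched as follows. This is a Python model of the behavior, not the actual code (the real dispatch is a C++ macro in `static_switch.h`), and the list of compiled head dimensions is an assumption based on the fallback targets the PR mentions:

```python
# Python model of the trimmed head-dimension dispatch: 160 and 224 no longer
# have dedicated kernels, so they round up to the next compiled size.
SUPPORTED_HEAD_DIMS = [32, 64, 96, 128, 192, 256]  # 160 and 224 removed

def dispatch_head_dim(head_dim: int) -> int:
    """Return the compiled kernel head dimension used for a given model head dim."""
    for supported in SUPPORTED_HEAD_DIMS:
        if head_dim <= supported:
            return supported
    raise ValueError(f"head_dim {head_dim} exceeds the maximum supported (256)")

print(dispatch_head_dim(160))  # 192 (previously a dedicated 160 kernel)
print(dispatch_head_dim(224))  # 256 (previously a dedicated 224 kernel)
```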

2. Test Enhancements

  • GQA Tests: Updated `onnxruntime/test/python/transformers/test_gqa.py` to detect whether the installed package is a quick build. If so, it tests only the supported data type (float16) and head dimension (128) for flash attention, and it uses `has_flash_attention(bf16=True)` when checking Flash Attention availability in BF16 tests, so those tests are skipped when the BF16 kernels are not compiled.
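The quick-build detection can be sketched like this. The helper name `is_quick_build` is illustrative (not necessarily what `test_gqa.py` uses); the "quick-build=1" field comes from the build info string shown in this PR:

```python
# Illustrative sketch: detect a quick build package from the ONNX Runtime
# build info string, which contains "quick-build=1" when the package was
# built with onnxruntime_QUICK_BUILD=ON.
def is_quick_build(build_info: str) -> bool:
    """Return True when the package was built with onnxruntime_QUICK_BUILD=ON."""
    return "quick-build=1" in build_info

info = "ORT Build Info: git-branch=main, git-commit-id=ecf164a945, quick-build=1, build type=Release"
print(is_quick_build(info))  # True
```

In the real tests, the string would come from `onnxruntime.get_build_info()`, and a positive result would restrict the parameter grid to float16 with head_size=128.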

Impact

  • Build Time: Faster compilation of the CUDA provider due to fewer Flash Attention kernels.
  • Functionality: Head dimensions 160 and 224 are no longer supported for Flash Attention. Models using these head dimensions will fall back to the next supported head dimension, such as 192 or 256.

Verification

  • Validated that the build completes successfully with the reduced kernel set.
  • test_gqa.py should pass or skip correctly based on hardware support.
  • Built the onnxruntime-gpu package with the `--cmake_extra_defines onnxruntime_QUICK_BUILD=ON` option and verified that the build info contains "quick-build=1", using the following Python script:

```python
import onnxruntime
print(onnxruntime.get_build_info())
```

The output looks like:

```
ORT Build Info: git-branch=main, git-commit-id=ecf164a945, quick-build=1, build type=Release
```

Contributor

Copilot AI left a comment


Pull request overview

This pull request optimizes Flash Attention build times by removing less common head dimensions (160 and 224) and introducing a quick build mode that limits compilation to FP16 kernels with head dimension 128.

Key Changes:

  • Removed Flash Attention kernel support for head dimensions 160 and 224 (both FP16 and BF16)
  • Added onnxruntime_QUICK_BUILD CMake option to build only FP16 with head dimension 128
  • Updated the dispatch logic in static_switch.h to skip removed head dimensions
  • Modified tests to detect and handle quick build mode appropriately

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| test_gqa.py | Detects quick build mode from build info and restricts test parameters to FP16 and head_size=128 when enabled |
| paged_attention.cc | Updated is_supported call to templated version for type checking |
| packed_multihead_attention.cc | Updated is_supported call to templated version for type checking |
| multihead_attention.cc | Updated is_supported call to templated version for type checking |
| group_query_attention.cc | Updated is_supported call to templated version for type checking |
| attention.cc | Updated is_supported call to templated version for type checking |
| static_switch.h | Added ORT_QUICK_BUILD macros to limit FP16_SWITCH and HEADDIM_SWITCH; removed head dim 160/224 from HEADDIM_SWITCH |
| flash_api.h | Added templated is_supported function for quick build type checking; minor comment corrections |
| flash_fwd_*hdim160*.cu | Deleted kernel instantiation files for head dimension 160 |
| flash_fwd_*hdim224*.cu | Deleted kernel instantiation files for head dimension 224 |
| onnxruntime_providers_cpu.cmake | Added quick build filtering logic to exclude non-128 and BF16 kernels |
| CMakeLists.txt | Added onnxruntime_QUICK_BUILD option and build info string generation |


tianleiwu merged commit 541d5da into main on Jan 7, 2026
97 checks passed
tianleiwu deleted the tlwu/flash_att_quick_build branch on January 7, 2026, 22:43
alex-spacemit pushed a commit to spacemit-com/onnxruntime that referenced this pull request Jan 20, 2026
alex-spacemit pushed a commit to spacemit-com/onnxruntime that referenced this pull request Jan 27, 2026