Expose AQLayout as tunable parameter for CKTile blockscale 8-warp GEMM kernels#2487
Merged
Expose AQLayout as tunable parameter for CKTile blockscale 8-warp GEMM kernels#2487
Conversation
For 8-warp CKTile blockscale GEMM kernels, the host wrapper currently transposes x_scale from row-major to column-major at runtime before every kernel launch. This is unnecessary when the kernel can natively read row-major AQ data — the CK pipeline already supports both layouts via the AQLayout trait in TileGemmQuantTraits. Changes: - Add `AQRowMajor` bool field to TileKernelInstance (default False for backward compatibility). When True on an 8-warp config, the kernel uses RowMajor AQLayout and skips the host-side transpose. - Add `AQRowMajor` template parameter to CreateTileGemmConfig / TileGemmConfig and expose as AQRowMajor_v. - Derive `aq_col_major` from `eight_waves && !AQRowMajor_v` to select the AQ layout in GemmTraits and condition the host-side transpose. - Add kernel IDs 12/13 as RowMajor variants of existing 8-warp kernels 10/11 in kernels_list_95x, so the tuner benchmarks both options. - Update gen_instances_cktile.py to emit the new template argument. - Also fix hardcoded strides (stride_A=K, stride_B=K) to read from tensor metadata, matching the fix in the stride PR. Made-with: Cursor
Tests verify: - TileKernelInstance name encoding with _aqrm suffix - is_eight_warp property correctness - AQRowMajor variants exist in candidate kernel dict - Both ColumnMajor and RowMajor 8-warp kernels match PyTorch reference output - RowMajor variant works with padded (non-contiguous) weight tensors from vLLM's _maybe_pad_fp8_weight Made-with: Cursor
Remove non-8-warp kernel 12 (2x2x1) that incorrectly had AQRowMajor set. Correct test instances to use actual 8-warp config (4x2x1). Made-with: Cursor
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…wmajor_tunable_rebase Made-with: Cursor # Conflicts: # csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh
The AQRowMajor kernel variant is only in kernels_list_95x, not kernels_list_942, so the test fails on MI325X (gfx942) because no AQRowMajor entries exist in candidate_kernels_cktile_dict. Made-with: Cursor
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
…arp kernels Introduces AQRowMajor as a configurable parameter for 8-warp CKTile blockscale GEMM kernels, allowing the tuner to select between ColumnMajor and RowMajor AQ layouts. Resolved conflicts with our eight_warps naming vs upstream eight_waves. Source: ROCm#2487
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
…RowMajor optimization Source: ROCm#2487
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 3, 2026
Contributor
There was a problem hiding this comment.
Pull request overview
This PR makes activation-quantization scale layout (AQLayout) a tunable parameter for CKTile FP8 blockscale GEMM 8-warp kernels on gfx950, enabling an 8-warp RowMajor variant that can avoid the per-launch host-side x_scale transpose.
Changes:
- Add an
AQRowMajortemplate parameter toTileGemmConfigand use it to select RowMajor vs ColumnMajor AQ layout for 8-warp kernels. - Extend CKTile kernel instance generation to encode/propagate
AQRowMajor, registering a new 8-warp RowMajor variant (kernel ID 12). - Add a new op test validating name encoding, numerical accuracy vs reference, and padded-weight stride handling for the RowMajor AQ variant.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py | Adds a gfx950-only test for RowMajor vs ColumnMajor AQ variants, including accuracy and padded weight stride coverage. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh | Introduces AQRowMajor in config and adjusts host-side x_scale handling + AQ layout selection logic. |
| csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py | Propagates AQRowMajor into generated C++ template instantiations. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_instance.py | Adds AQRowMajor to instance metadata, name suffix encoding, and registers the new RowMajor 8-warp kernel variant. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Contributor
Author
|
@valarLip Would you have time to review this PR? The CI passes with a recent merge with main. |
1 task
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 22, 2026
nholmber
added a commit
to nholmber/aiter
that referenced
this pull request
Apr 25, 2026
Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 on MI355X using CK + CK-TILE backends with splitK support. Depends on: - PR ROCm#2862 (CK bump for stride fix in CK-TILE blockscale) - PR ROCm#2541 (splitK support for CK/CK-TILE blockscale GEMMs) - PR ROCm#2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)
valarLip
approved these changes
May 1, 2026
chun-wan
pushed a commit
that referenced
this pull request
May 4, 2026
…M kernels (#2487) * Add AQRowMajor tunable for CKTile blockscale 8-warp kernels For 8-warp CKTile blockscale GEMM kernels, the host wrapper currently transposes x_scale from row-major to column-major at runtime before every kernel launch. This is unnecessary when the kernel can natively read row-major AQ data — the CK pipeline already supports both layouts via the AQLayout trait in TileGemmQuantTraits. Changes: - Add `AQRowMajor` bool field to TileKernelInstance (default False for backward compatibility). When True on an 8-warp config, the kernel uses RowMajor AQLayout and skips the host-side transpose. - Add `AQRowMajor` template parameter to CreateTileGemmConfig / TileGemmConfig and expose as AQRowMajor_v. - Derive `aq_col_major` from `eight_waves && !AQRowMajor_v` to select the AQ layout in GemmTraits and condition the host-side transpose. - Add kernel IDs 12/13 as RowMajor variants of existing 8-warp kernels 10/11 in kernels_list_95x, so the tuner benchmarks both options. - Update gen_instances_cktile.py to emit the new template argument. - Also fix hardcoded strides (stride_A=K, stride_B=K) to read from tensor metadata, matching the fix in the stride PR. Made-with: Cursor * Add tests for CKTile blockscale FP8 GEMM AQRowMajor optimization Tests verify: - TileKernelInstance name encoding with _aqrm suffix - is_eight_warp property correctness - AQRowMajor variants exist in candidate kernel dict - Both ColumnMajor and RowMajor 8-warp kernels match PyTorch reference output - RowMajor variant works with padded (non-contiguous) weight tensors from vLLM's _maybe_pad_fp8_weight Made-with: Cursor * Fix AQRowMajor kernel variant and test assertions Remove non-8-warp kernel 12 (2x2x1) that incorrectly had AQRowMajor set. Correct test instances to use actual 8-warp config (4x2x1). Made-with: Cursor * Fix f-string Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * run black * Update gemm_a8w8_blockscale_cktile_instance.py formatting * Update gemm_a8w8_blockscale_cktile_instance.py formatting * Gate AQRowMajor test to gfx950 only The AQRowMajor kernel variant is only in kernels_list_95x, not kernels_list_942, so the test fails on MI325X (gfx942) because no AQRowMajor entries exist in candidate_kernels_cktile_dict. Made-with: Cursor * Update op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
For CKTile blockscale FP8 GEMM 8-warp kernels on gfx950, the activation quantization scale (
x_scale) is currently transposed from RowMajor to ColumnMajor on the host before every kernel launch. This PR makes theAQLayouta tunable parameter, allowing the kernel to natively readx_scalein RowMajor layout and skip the host-side transpose entirely.Changes
gemm_a8w8_blockscale_cktile_common.cuh: AddAQRowMajortemplate parameter toTileGemmConfig. When enabled, the 8-warp pipeline uses RowMajorAQLayoutand the host-sidex_scaletranspose + allocation is skipped.gemm_a8w8_blockscale_cktile_instance.py: AddAQRowMajorfield toTileKernelInstancewithis_eight_warpproperty. Register a new RowMajor variant (kernel ID 12) alongside the existing ColumnMajor 8-warp kernel (ID 11), so the tuner can evaluate both.gen_instances_cktile.py: PropagateAQRowMajorinto the generated C++ template instantiation.test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py: New test covering instance name encoding, numerical accuracy (RowMajor vs ColumnMajor vs PyTorch reference), and padded weight stride handling (simulating vLLM's_maybe_pad_fp8_weight).Performance
Benchmarked on gfx950 using the tuning script (
gemm_a8w8_blockscale_tune.py --libtype cktile). Comparing the two 8-warp variants head-to-head:The tuner automatically selects the best variant per shape — no manual configuration needed.
Non-8-warp kernels
Non-8-warp kernels always use RowMajor
AQLayoutand are unaffected by this change. TheAQRowMajorflag is only meaningful for 8-warp configurations.Test plan
test_instance_names: Verifies_aqrmsuffix encoding andis_eight_warppropertytest_accuracy: RowMajor and ColumnMajor outputs match PyTorch FP32 reference across 4 shapestest_padded_weight_stride: RowMajor kernel handles non-contiguous (padded) weight tensors correctly