Expose AQLayout as tunable parameter for CKTile blockscale 8-warp GEMM kernels #2487

Merged
valarLip merged 16 commits into main from samremes/cktile_aq_rowmajor_tunable on May 1, 2026
@samremes
Contributor

Summary

For CKTile blockscale FP8 GEMM 8-warp kernels on gfx950, the activation quantization scale (x_scale) is currently transposed from RowMajor to ColumnMajor on the host before every kernel launch. This PR makes the AQLayout a tunable parameter, allowing the kernel to natively read x_scale in RowMajor layout and skip the host-side transpose entirely.
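
To make the layout difference concrete, here is a minimal pure-Python sketch of what a RowMajor to ColumnMajor repack of a small scale matrix looks like. The helper names and the tiny 2x3 shape are illustrative assumptions, not code from this PR; the real host wrapper operates on GPU tensors.

```python
def row_major_index(m, k, num_rows, num_cols):
    """Flat offset of element (m, k) when rows are stored contiguously."""
    return m * num_cols + k


def col_major_index(m, k, num_rows, num_cols):
    """Flat offset of element (m, k) when columns are stored contiguously."""
    return k * num_rows + m


def transpose_to_col_major(flat_row_major, num_rows, num_cols):
    """Repack a row-major buffer into column-major order.

    This stands in for the per-launch host-side transpose that the
    RowMajor kernel variant is meant to avoid.
    """
    out = [0.0] * (num_rows * num_cols)
    for m in range(num_rows):
        for k in range(num_cols):
            src = row_major_index(m, k, num_rows, num_cols)
            dst = col_major_index(m, k, num_rows, num_cols)
            out[dst] = flat_row_major[src]
    return out


# Hypothetical x_scale: 2 rows, 3 scale groups along K, stored row-major.
x_scale_rm = [1.0, 2.0, 3.0,
              4.0, 5.0, 6.0]
x_scale_cm = transpose_to_col_major(x_scale_rm, 2, 3)
print(x_scale_cm)  # [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
```

Both buffers hold the same logical matrix; only the addressing differs, which is why a kernel that reads RowMajor directly can skip the repack and its extra allocation.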

Changes

  • gemm_a8w8_blockscale_cktile_common.cuh: Add AQRowMajor template parameter to TileGemmConfig. When enabled, the 8-warp pipeline uses RowMajor AQLayout and the host-side x_scale transpose + allocation is skipped.
  • gemm_a8w8_blockscale_cktile_instance.py: Add AQRowMajor field to TileKernelInstance with is_eight_warp property. Register a new RowMajor variant (kernel ID 12) alongside the existing ColumnMajor 8-warp kernel (ID 11), so the tuner can evaluate both.
  • gen_instances_cktile.py: Propagate AQRowMajor into the generated C++ template instantiation.
  • test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py: New test covering instance name encoding, numerical accuracy (RowMajor vs ColumnMajor vs PyTorch reference), and padded weight stride handling (simulating vLLM's _maybe_pad_fp8_weight).
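
As a rough illustration of the instance-side change, the sketch below models a kernel instance with an AQRowMajor flag, an is_eight_warp property, and an _aqrm name suffix. Only those three names come from this PR; the class name, the warp_m/warp_n/warp_k fields, the base name format, and the "product of warps equals 8" rule are assumptions made for the example, and the real TileKernelInstance has more fields.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TileKernelInstanceSketch:
    # Warps per block along M, N, K (hypothetical field names).
    warp_m: int
    warp_n: int
    warp_k: int
    # Default False keeps the existing ColumnMajor behaviour.
    AQRowMajor: bool = False

    @property
    def is_eight_warp(self) -> bool:
        # Assumption: "8-warp" means the warp grid multiplies out to 8.
        return self.warp_m * self.warp_n * self.warp_k == 8

    @property
    def name(self) -> str:
        # Hypothetical base name; only the _aqrm suffix is from the PR.
        base = f"cktile_sketch_{self.warp_m}x{self.warp_n}x{self.warp_k}"
        return base + ("_aqrm" if self.AQRowMajor else "")


col_major_8w = TileKernelInstanceSketch(4, 2, 1)
row_major_8w = TileKernelInstanceSketch(4, 2, 1, AQRowMajor=True)
four_warp = TileKernelInstanceSketch(2, 2, 1)

print(col_major_8w.name, col_major_8w.is_eight_warp)
print(row_major_8w.name, row_major_8w.is_eight_warp)
print(four_warp.name, four_warp.is_eight_warp)
```

Encoding the flag in the name keeps the two 8-warp variants distinguishable in generated file names and tuner output.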

Performance

Benchmarked on gfx950 using the tuning script (gemm_a8w8_blockscale_tune.py --libtype cktile). Comparing the two 8-warp variants head-to-head:

  • Shapes with large N, small K (e.g. Nx7168, K=2048): RowMajor is 3–8% faster than ColumnMajor, as the transpose overhead is significant relative to compute time. The tuner selects the RowMajor kernel as the overall best for M=1024.
  • Shapes with large K (e.g. Nx2048, K=7168): ColumnMajor remains 8–21% faster, likely because the longer compute phase amortizes the transpose cost.

The tuner automatically selects the best variant per shape — no manual configuration needed.
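
Per-shape selection can be pictured as taking the minimum over candidate timings. The sketch below is a simplified stand-in for the tuner, not its implementation; the function name, the shape labels, and the timing values are placeholders chosen for illustration and are not measurements from this PR.

```python
def pick_fastest(timings_us):
    """Map each shape to the kernel id with the lowest recorded time."""
    return {
        shape: min(per_kernel, key=per_kernel.get)
        for shape, per_kernel in timings_us.items()
    }


# Placeholder timings in microseconds (NOT benchmark results).
# Kernel 11 = ColumnMajor 8-warp, kernel 12 = RowMajor 8-warp.
placeholder_timings = {
    "large_N_small_K": {11: 10.0, 12: 9.5},
    "large_K": {11: 8.0, 12: 9.0},
}

best = pick_fastest(placeholder_timings)
print(best)  # {'large_N_small_K': 12, 'large_K': 11}
```

Because both variants are registered as ordinary candidates, neither layout has to be a global default.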

Non-8-warp kernels

Non-8-warp kernels always use RowMajor AQLayout and are unaffected by this change. The AQRowMajor flag is only meaningful for 8-warp configurations.
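
The commit message in this PR states that aq_col_major is derived from eight_waves && !AQRowMajor_v. A small Python rendering of that rule, with hypothetical function names, shows why the flag has no effect outside 8-warp configurations:

```python
def aq_layout(eight_waves: bool, aq_row_major: bool) -> str:
    """Mirror of `aq_col_major = eight_waves && !AQRowMajor_v`."""
    aq_col_major = eight_waves and not aq_row_major
    return "ColumnMajor" if aq_col_major else "RowMajor"


def needs_host_transpose(eight_waves: bool, aq_row_major: bool) -> bool:
    """The host-side x_scale transpose is only needed for ColumnMajor AQ."""
    return aq_layout(eight_waves, aq_row_major) == "ColumnMajor"


for eight_waves in (False, True):
    for aq_row_major in (False, True):
        print(
            f"eight_waves={eight_waves!s:5} AQRowMajor={aq_row_major!s:5} "
            f"-> {aq_layout(eight_waves, aq_row_major)}"
        )
```

Only the combination of an 8-warp config with AQRowMajor left at False yields ColumnMajor; every other combination resolves to RowMajor.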

Test plan

  • test_instance_names: Verifies _aqrm suffix encoding and is_eight_warp property
  • test_accuracy: RowMajor and ColumnMajor outputs match PyTorch FP32 reference across 4 shapes
  • test_padded_weight_stride: RowMajor kernel handles non-contiguous (padded) weight tensors correctly
  • The tuner picks up kernel 12 and selects it when it's the fastest
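
The padded-stride case is worth a small illustration. In the pure-Python sketch below (helper names and values are assumptions for the example), a weight with logical shape N=2, K=3 is stored with one padding element per row, so its row stride is 4. Reading it with a hardcoded stride of K lands on padding, while reading with the stride taken from the buffer's own metadata returns the right element.

```python
def make_padded_rows(rows, pad, fill=0):
    """Flatten rows into one buffer, appending `pad` filler elements per row.

    Returns the flat buffer and its row stride (row length + pad).
    """
    flat = []
    for row in rows:
        flat.extend(row)
        flat.extend([fill] * pad)
    stride = len(rows[0]) + pad
    return flat, stride


def read(flat, stride, n, k):
    """Read logical element (n, k) from a flat buffer with the given row stride."""
    return flat[n * stride + k]


weight = [[1, 2, 3],
          [4, 5, 6]]                      # logical shape N=2, K=3
flat, stride_b = make_padded_rows(weight, pad=1, fill=-1)
K = 3

correct = read(flat, stride_b, 1, 0)      # uses the real stride (4)
wrong = read(flat, K, 1, 0)               # hardcoded stride_B = K
print(stride_b, correct, wrong)           # 4 4 -1
```

This is the failure mode that a non-contiguous weight would trigger if the kernel arguments assumed stride_B = K.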

For 8-warp CKTile blockscale GEMM kernels, the host wrapper currently
transposes x_scale from row-major to column-major at runtime before
every kernel launch.  This is unnecessary when the kernel can natively
read row-major AQ data — the CK pipeline already supports both layouts
via the AQLayout trait in TileGemmQuantTraits.

Changes:
- Add `AQRowMajor` bool field to TileKernelInstance (default False for
  backward compatibility).  When True on an 8-warp config, the kernel
  uses RowMajor AQLayout and skips the host-side transpose.
- Add `AQRowMajor` template parameter to CreateTileGemmConfig /
  TileGemmConfig and expose as AQRowMajor_v.
- Derive `aq_col_major` from `eight_waves && !AQRowMajor_v` to select
  the AQ layout in GemmTraits and condition the host-side transpose.
- Add kernel IDs 12/13 as RowMajor variants of existing 8-warp kernels
  10/11 in kernels_list_95x, so the tuner benchmarks both options.
- Update gen_instances_cktile.py to emit the new template argument.
- Also fix hardcoded strides (stride_A=K, stride_B=K) to read from
  tensor metadata, matching the fix in the stride PR.

Made-with: Cursor

Tests verify:
- TileKernelInstance name encoding with _aqrm suffix
- is_eight_warp property correctness
- AQRowMajor variants exist in candidate kernel dict
- Both ColumnMajor and RowMajor 8-warp kernels match
  PyTorch reference output
- RowMajor variant works with padded (non-contiguous) weight
  tensors from vLLM's _maybe_pad_fp8_weight

Made-with: Cursor
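
For readers unfamiliar with what a reference comparison for a blockscale GEMM involves, here is a pure-Python stand-in for such a reference. It is a sketch under stated assumptions, not the test's code: the test itself compares against a PyTorch reference, and the scale shapes used here (x_scale per row and K-block, w_scale per N-block and K-block) and the function name are assumptions.

```python
def blockscale_gemm_reference(x_q, w_q, x_scale, w_scale, block_n, block_k):
    """Dequantize-and-multiply reference.

    out[m][n] = sum_k (x_q[m][k] * x_scale[m][k // block_k])
                    * (w_q[n][k] * w_scale[n // block_n][k // block_k])
    """
    M, K, N = len(x_q), len(x_q[0]), len(w_q)
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                kb = k // block_k
                x_val = x_q[m][k] * x_scale[m][kb]
                w_val = w_q[n][k] * w_scale[n // block_n][kb]
                acc += x_val * w_val
            out[m][n] = acc
    return out


# Tiny hypothetical inputs: M=1, N=2, K=4, blocks of 2 along N and K.
x_q = [[1, 2, 3, 4]]
w_q = [[1, 1, 1, 1],
       [2, 2, 2, 2]]
x_scale = [[0.5, 2.0]]     # one scale per row per K-block
w_scale = [[1.0, 0.5]]     # one scale per N-block per K-block

out = blockscale_gemm_reference(x_q, w_q, x_scale, w_scale, block_n=2, block_k=2)
print(out)  # [[8.5, 17.0]]
```

The logical value of x_scale[m][kb] is the same in either AQ layout, so RowMajor and ColumnMajor kernels are expected to agree with one reference.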
Remove non-8-warp kernel 12 (2x2x1) that incorrectly had AQRowMajor
set. Correct test instances to use actual 8-warp config (4x2x1).

Made-with: Cursor
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2487 --add-label <label>

samremes and others added 4 commits March 26, 2026 15:57
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…wmajor_tunable_rebase

Made-with: Cursor

# Conflicts:
#	csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh
@samremes samremes marked this pull request as ready for review March 31, 2026 16:30
@samremes samremes requested a review from a team March 31, 2026 16:30
The AQRowMajor kernel variant is only in kernels_list_95x, not
kernels_list_942, so the test fails on MI325X (gfx942) because no
AQRowMajor entries exist in candidate_kernels_cktile_dict.

Made-with: Cursor
nholmber added a commit to nholmber/aiter that referenced this pull request Apr 3, 2026
…arp kernels

Introduces AQRowMajor as a configurable parameter for 8-warp CKTile
blockscale GEMM kernels, allowing the tuner to select between
ColumnMajor and RowMajor AQ layouts.

Resolved conflicts with our eight_warps naming vs upstream eight_waves.

Source: ROCm#2487
@samremes samremes requested a review from Copilot April 8, 2026 12:35
Contributor

Copilot AI left a comment

Pull request overview

This PR makes activation-quantization scale layout (AQLayout) a tunable parameter for CKTile FP8 blockscale GEMM 8-warp kernels on gfx950, enabling an 8-warp RowMajor variant that can avoid the per-launch host-side x_scale transpose.

Changes:

  • Add an AQRowMajor template parameter to TileGemmConfig and use it to select RowMajor vs ColumnMajor AQ layout for 8-warp kernels.
  • Extend CKTile kernel instance generation to encode/propagate AQRowMajor, registering a new 8-warp RowMajor variant (kernel ID 12).
  • Add a new op test validating name encoding, numerical accuracy vs reference, and padded-weight stride handling for the RowMajor AQ variant.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py | Adds a gfx950-only test for RowMajor vs ColumnMajor AQ variants, including accuracy and padded weight stride coverage. |
| csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_cktile_common.cuh | Introduces AQRowMajor in config and adjusts host-side x_scale handling + AQ layout selection logic. |
| csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py | Propagates AQRowMajor into generated C++ template instantiations. |
| csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile_instance.py | Adds AQRowMajor to instance metadata, name suffix encoding, and registers the new RowMajor 8-warp kernel variant. |

Comment thread op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py Outdated
Comment thread op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py
@samremes
Contributor Author

@valarLip Would you have time to review this PR? The CI passes with a recent merge with main.

nholmber added a commit to nholmber/aiter that referenced this pull request Apr 22, 2026
PR ROCm#2862's CK bump (cbfb3e242) lacks the ABQuantGrouped/GemmTraits
APIs needed by PRs ROCm#2541 and ROCm#2487. Update to 020b6f435 which has
both the stride fix and the required CK-TILE blockscale APIs.
nholmber added a commit to nholmber/aiter that referenced this pull request Apr 25, 2026
Tuned 1482 shapes (TP1/TP2/TP4) for Qwen/Qwen3-Next-80B-A3B-Instruct-FP8
on MI355X using CK + CK-TILE backends with splitK support.

Depends on:
- PR ROCm#2862 (CK bump for stride fix in CK-TILE blockscale)
- PR ROCm#2541 (splitK support for CK/CK-TILE blockscale GEMMs)
- PR ROCm#2487 (AQLayout tunable for CK-TILE blockscale 8-warp kernels)
@valarLip valarLip merged commit 23b04cc into main May 1, 2026
51 of 55 checks passed
@valarLip valarLip deleted the samremes/cktile_aq_rowmajor_tunable branch May 1, 2026 14:03
chun-wan pushed a commit that referenced this pull request May 4, 2026
…M kernels (#2487)

* Add AQRowMajor tunable for CKTile blockscale 8-warp kernels

For 8-warp CKTile blockscale GEMM kernels, the host wrapper currently
transposes x_scale from row-major to column-major at runtime before
every kernel launch.  This is unnecessary when the kernel can natively
read row-major AQ data — the CK pipeline already supports both layouts
via the AQLayout trait in TileGemmQuantTraits.

Changes:
- Add `AQRowMajor` bool field to TileKernelInstance (default False for
  backward compatibility).  When True on an 8-warp config, the kernel
  uses RowMajor AQLayout and skips the host-side transpose.
- Add `AQRowMajor` template parameter to CreateTileGemmConfig /
  TileGemmConfig and expose as AQRowMajor_v.
- Derive `aq_col_major` from `eight_waves && !AQRowMajor_v` to select
  the AQ layout in GemmTraits and condition the host-side transpose.
- Add kernel IDs 12/13 as RowMajor variants of existing 8-warp kernels
  10/11 in kernels_list_95x, so the tuner benchmarks both options.
- Update gen_instances_cktile.py to emit the new template argument.
- Also fix hardcoded strides (stride_A=K, stride_B=K) to read from
  tensor metadata, matching the fix in the stride PR.

Made-with: Cursor

* Add tests for CKTile blockscale FP8 GEMM AQRowMajor optimization

Tests verify:
- TileKernelInstance name encoding with _aqrm suffix
- is_eight_warp property correctness
- AQRowMajor variants exist in candidate kernel dict
- Both ColumnMajor and RowMajor 8-warp kernels match
  PyTorch reference output
- RowMajor variant works with padded (non-contiguous) weight
  tensors from vLLM's _maybe_pad_fp8_weight

Made-with: Cursor

* Fix AQRowMajor kernel variant and test assertions

Remove non-8-warp kernel 12 (2x2x1) that incorrectly had AQRowMajor
set. Correct test instances to use actual 8-warp config (4x2x1).

Made-with: Cursor

* Fix f-string

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* run black

* Update gemm_a8w8_blockscale_cktile_instance.py formatting

* Update gemm_a8w8_blockscale_cktile_instance.py formatting

* Gate AQRowMajor test to gfx950 only

The AQRowMajor kernel variant is only in kernels_list_95x, not
kernels_list_942, so the test fails on MI325X (gfx942) because no
AQRowMajor entries exist in candidate_kernels_cktile_dict.

Made-with: Cursor

* Update op_tests/test_gemm_a8w8_blockscale_cktile_aq_rowmajor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>