update #2944
Merged
demonsan merged 119 commits into blyu/opus_gemm on Apr 29, 2026
Conversation
* Update runner-config.yml
* CI: surface runner-config mapping in AMD CI job monitor

  Load GPU architecture and count from runner-config.yml so the runner fleet summary shows the configured inventory for each label. Trigger the monitor workflow when runner mappings change and install PyYAML for the runner report job.
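A minimal sketch of the monitor-side lookup, assuming a hypothetical runner-config.yml schema that maps runner labels to GPU architecture and count (the actual layout in the repo may differ):

```python
# Hypothetical schema: {"runners": {"mi355x": {"arch": "gfx950", "count": 8}, ...}}
import yaml  # PyYAML, installed for the runner report job

with open("runner-config.yml") as f:
    cfg = yaml.safe_load(f)

# Summarize the configured fleet inventory per label.
for label, info in sorted(cfg.get("runners", {}).items()):
    print(f"{label}: arch={info['arch']}, gpus={info['count']}")
```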
Co-authored-by: gyohuangxin <42127654+gyohuangxin@users.noreply.github.com>
Co-authored-by: Xin Huang <Xin.Huang@amd.com>
Add a dedicated GitHub Actions workflow for op_tests/opus so OPUS validation runs independently on MI35X and MI325 runners without being mixed into the main Aiter test shards.
Co-authored-by: solin <bingzhou@amd.com>
Split main-branch concurrency by event type so scheduled runs do not block push-triggered validation when a long queued job keeps the nightly workflow open.
* CI: Enable DeepSeek ATOM tests on MI35X
* CI: Use /models cache for MI35X ATOM DeepSeek test

  Route the MI35X DeepSeek job to the runner-local /models cache so it avoids downloading into /run, and make the output artifact name unique now that two DeepSeek variants run in the same workflow.
* CI: Mount /models into MI35X ATOM test container

  Pass the runner's shared /models cache into atom_aiter_test so MI35X DeepSeek jobs can use the mounted model path.
* replace ck with opus
* fix compile issue
* fix waterfall and use buffer inst for lse
* Replace ck with opus for mla metadata
* Add dim=512 fp32 case for wave32
* docs: add ISA-level kernel optimization guide using LLVM tools

  Step-by-step guide covering the full LLVM-based workflow for inspecting, modifying, and recompiling AITER GPU kernel ISA: disassemble, extract reassemblable .s, round-trip recompile with binary-identical .text verification, and profile with rocprofv3. Includes Python extraction script handling branch label word-offset addressing, llvm-objcopy section swap for preserving kernel metadata, and rocprofv3 kernel-trace + ATT profiling instructions.
* docs: add ISA optimization code examples and Dockerfile

  Runnable companion to the ISA kernel optimization guide:
  - extract_asm.py: standalone ASM extraction with CLI interface
  - analyze_kernel.py: instruction mix analysis and rocprofv3 profile parser
  - roundtrip.sh: end-to-end disassemble/extract/recompile/verify script
  - Dockerfile: ROCm 7.2.1 dev environment with all tools pre-installed, including ATT trace decoder built from source
* style: fix black formatting and ruff lint in ISA optimization examples
  - Rename loop var 'l' to 'ln' to fix E741 (ambiguous variable name)
  - Remove extraneous f-prefix on strings without placeholders (F541)
  - Apply black auto-formatting
* style: fix black formatting for CI compatibility
  - Add blank line between module docstring and imports (E302)
  - Collapse multiline f-string call arguments

---------

Co-authored-by: Peng Sun <pensun@Pengs-MacBook-Pro.local>
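As a rough illustration of the instruction-mix pass that analyze_kernel.py performs, here is a hedged Python sketch over llvm-objdump disassembly output; the opcode-prefix categories and regex are illustrative, not the shipped script:

```python
import re
from collections import Counter

def instruction_mix(disasm_path):
    """Count AMDGCN instructions by opcode-prefix category (s_, v_, ds_, ...)."""
    mix = Counter()
    prefix_re = re.compile(r"\s+(s|v|ds|buffer|global|flat)_\w+")
    with open(disasm_path) as f:
        for ln in f:  # 'ln', not 'l', per the E741 fix above
            m = prefix_re.match(ln)
            if m:
                mix[m.group(1)] += 1
    return mix

# Example: print(instruction_mix("kernel.s").most_common())
```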
* Add run_config/compare support to GemmTuner (bf16)
  - Add config_env_name, _clear_op_caches(), and run_config() to GemmTuner so --run_config and --compare flags work for bf16 GEMM tuning
  - Update bubbly-exploring-turtle.md plan doc to reflect the full implementation, including --compare, config_env_name, cache clearing, and post-tune config switching architecture

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Add --run_config and --compare benchmark support to all tuners

  Add infrastructure in base_tuner.py for production operator benchmarking:
  - --run_config: benchmark only, no tuning
  - --compare: pre-tune benchmark, tune, post-tune benchmark with comparison table
  - Config env switching and cache clearing for post-tune benchmarks

  Implement run_config() and _clear_op_caches() in all CK-based tuners: gemm_a8w8, gemm_a8w8_bpreshuffle, gemm_a8w8_blockscale, gemm_a8w8_blockscale_bpreshuffle, gemm_a4w4_blockscale, gemm_moe_2stages, batched_gemm_a8w8, batched_gemm_bf16
* Revert unintended composable_kernel submodule change
* Fix review comments and remove intermediate plan docs
  - Save/restore the original AITER_REBUILD value instead of hardcoding 0
  - Use defensive strip() for mixed-type object columns in _read_csv
  - Remove docs/bubbly-exploring-turtle.md and docs/run_config_benchmark.md (consolidated into csrc/.claude/add_run_config_to_tuner.md)
* update ref rtol, atol
* Fix tuner cache invalidation, run_config preshuffle, and compare workflow
  - Fix _clear_op_caches for all tuners: properly clear lru_cache and internal dict/attribute caches (a4w4, a8w8 variants, fmoe) so post-tune benchmarks use freshly tuned configs instead of stale ones.
  - Fix fmoe run_config: preshuffle weights before calling fused_moe to match the production layout (the tuner always tunes with bpreshuffle=True), preventing preshuffle on/off module mismatch and 99%+ error.
  - Add a defensive warning in fused_moe get_2stage_cfgs when a tuned config is found but is_shuffled=False.
  - Fix run_config to read shapes from the tuned CSV and set the config env var.
  - Fix the --compare workflow: run the post-tune benchmark before tune_summary to avoid summary errors blocking verification.
  - Fix the base_tuner _set_config_env_for_run_config return value.
  - Use print instead of logger.info for benchmark tables.
* fix format
* fix format
* update readme
* fix lint error
* fix lint
* update csv only when perf improves
* format
* fix lint
* revert format for some files
* clarify compare and gated update flow

  Make --compare keep a candidate csv and require --update_improved before writing back tuned results so the CLI stays explicit and easier to extend. Made-with: Cursor
* fix flydsl GemmTuner review issues

  Trigger FlyDSL package validation before importing tuning kernels and align the FlyDSL bias cast order with the runtime path to avoid unintended dtype promotion. Made-with: Cursor
* update
* revert claude md
* update shape_grouped
* fix format
* fix bug
* fix lint error

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
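The cache-clearing idea behind _clear_op_caches() is roughly the following; the module scan and `_cache` naming here are hypothetical, and the real tuners clear specific known caches rather than scanning:

```python
def _clear_op_caches(modules):
    """Clear lru_cache wrappers and dict caches so post-tune benchmarks
    re-resolve kernels from the freshly tuned config, not stale entries."""
    for mod in modules:
        for name in dir(mod):
            obj = getattr(mod, name)
            if hasattr(obj, "cache_clear"):           # functools.lru_cache wrapper
                obj.cache_clear()
            elif isinstance(obj, dict) and name.endswith("_cache"):
                obj.clear()                           # internal dict cache
```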
* Move swiglu to a util file + add optional residual flag
* refactor reduce and make it compatible with >65k tokens
* Update aiter/ops/triton/_triton_kernels/moe/reduce.py

  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: Lukasz Burzawa <lukasz.burzawa@amd.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…#2744)

When `use_asm_v3` is `false`, `fmha_fwd_v3()` correctly returns `-1` to fall back to the CK path, but it also emits a misleading "unsupported condition in fwd_v3!!!" warning. This is not an unsupported condition — the caller intentionally opted out of v3. Separate the `use_asm_v3` check into an early return without a warning, so the `AITER_LOG_WARNING` only fires for genuinely unsupported parameter combinations (wrong head dims, unsupported dtypes, bias, dropout, wrong arch).

Made-with: Cursor
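The control-flow change, sketched in Python (the real function is C++; `supported` and `log_warning` are placeholders for the actual parameter checks and AITER_LOG_WARNING):

```python
def supported(args):   # placeholder for the real head-dim/dtype/bias/dropout/arch checks
    return True

def log_warning(msg):  # placeholder for AITER_LOG_WARNING
    print("WARNING:", msg)

def fmha_fwd_v3(args, use_asm_v3):
    if not use_asm_v3:
        return -1  # intentional opt-out: fall back to CK, no warning
    if not supported(args):
        log_warning("unsupported condition in fwd_v3!!!")
        return -1  # genuinely unsupported: warn, then fall back
    return 0       # v3 path would run here
```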
* feat(aot): add MoE FlyDSL AOT pre-compilation module
* refactor(aot): remove --stage flag from MoE AOT module
* reformat
* feat(aot): integrate FlyDSL MoE AOT precompilation into setup.py

  Move moe.py into aiter/aot/flydsl/, support multiple CSV configs, simplify compile_one_config to use the COMPILE_ONLY=1 env var, and add a MoE AOT pre-compilation step to the package build in setup.py.
* feat(aot): FlyDSL MoE AOT with COMPILE_ONLY dummy-tensor precompilation

  Rework FlyDSL MoE AOT to use COMPILE_ONLY=1 with dummy tensors instead of run_kernel, removing all HIP op dependencies (moe_sorting, shuffle_weight, etc.) from the precompilation path. Changes:
  - Replace _run_kernel with _precompile_to_cache using torch.zeros dummy tensors and COMPILE_ONLY=1 for pkl cache generation
  - Add sys.modules bridging in setup.py so aiter.jit.core reuses the same module instance loaded via sys.path
  - Auto-detect bundled flydsl_cache in aiter/__init__.py and set FLYDSL_RUNTIME_CACHE_DIR
  - Add KeyError to the aiter/__init__.py exception handler for robustness
  - Support multiple CSV configs (dsv3 + kimik2)
  - Remove the run_kernel parameter and test_bad_tile logic
* update flydsl
* update flydsl
* adapt hgemm
* fix black
* add flydsl gemm aot precompile support

---------

Signed-off-by: zhiding512 <zhimding@amd.com>
Signed-off-by: zhimding <zhiming.ding@amd.com>
Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
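The COMPILE_ONLY dummy-tensor precompile pattern, as a hedged sketch; the kernel entry point and its signature are hypothetical (the real module lives in aiter/aot/flydsl/):

```python
import os
import torch

def precompile_to_cache(kernel, shape, dtype=torch.bfloat16):
    """Compile a FlyDSL kernel into the pkl cache without launching it."""
    os.environ["COMPILE_ONLY"] = "1"   # codegen + cache write, no execution
    dummy = torch.zeros(shape, dtype=dtype, device="cuda")
    kernel(dummy)  # populates FLYDSL_RUNTIME_CACHE_DIR for this shape/dtype
```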
* gather support qk_nope_head_dim != v_head_dim
* fix 192 pad
…2733)

* feat: add/retune BF16 GEMM configs with FlyDSL backend for 6 models

  Tuned on MI355X (gfx950) with all backends competing (ASM, hipBLASLt, Triton, FlyDSL). New tuned configs for Llama 70B, Llama 405B, Qwen 32B. Re-tuned existing configs for GPT-OSS, DSV3, Kimi-K2 to include FlyDSL.

  Backend wins across 708 total shapes:
  - hipBLASLt: 472 (66.7%)
  - ASM: 131 (18.5%)
  - FlyDSL: 70 (9.9%)
  - Triton: 7 (1.0%)
  - Mixed/other: 28 (4.0%)
* feat: retune BF16 GEMM without hipBLASLt, add GLM-5 and 3 new models

  Re-tuned all BF16 GEMM configs on MI355X (gfx950) with --libtype asm,triton,flydsl (no hipBLASLt). Added GLM-5 (88 shapes from CI log) and new configs for Llama 70B, Llama 405B, Qwen 32B.

  Backend wins across 796 total shapes (7 models):
  - ASM: 437 (54.9%)
  - FlyDSL: 224 (28.1%)
  - Triton: 135 (17.0%)

  Per-model breakdown:
  - GPT-OSS (57): asm=54, triton=3 (bias=True, no FlyDSL support)
  - DSV3 (58): flydsl=22, triton=18, asm=18
  - Kimi-K2 (125): asm=77, flydsl=46, triton=2
  - GLM-5 (88): asm=42, flydsl=30, triton=16
  - Llama 70B (156): asm=84, flydsl=49, triton=23
  - Llama 405B (156): asm=89, flydsl=43, triton=24
  - Qwen 32B (156): asm=73, triton=49, flydsl=34

  Tuning time without hipBLASLt: 4h total (long pole: 405B @ 4h) vs with hipBLASLt: 10h+ total (long pole: 405B @ 8h+)
* Hoist inspect.signature/typing.get_type_hints out of per-call ctypes dispatch

  These two introspection calls were recomputed on every invocation of the ctypes caller closure (~79µs + ~91µs per call). Since the decorated function's signature and type hints are immutable, compute them once at decoration time and capture via closure. Made-with: Cursor
* update ruff format
* update black format

---------

Co-authored-by: amd-ruitang3 <rui.tang2@amd.com>
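A minimal sketch of the hoist (decorator and helper names are illustrative, not the AITER code): the introspection moves from the per-call closure body to decoration time, where it runs once and is captured by the closure:

```python
import functools
import inspect
import typing

def ctypes_dispatch(fn):
    sig = inspect.signature(fn)        # was ~79µs per call, now paid once
    hints = typing.get_type_hints(fn)  # was ~91µs per call, now paid once

    @functools.wraps(fn)
    def caller(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)  # reuses the captured signature
        bound.apply_defaults()
        # ... convert bound.arguments to ctypes using `hints` ...
        return fn(*bound.args, **bound.kwargs)

    return caller
```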
…ise (#2732)

* Handle FlyDSL LDS limits and candidate failures

  Use shared-memory-per-block queries to keep FlyDSL LDS checks architecture-aware, and surface candidate failures as concise runtime warnings so tuning can continue without noisy tracebacks. Made-with: Cursor
* Keep tuner topk local per shape

  Avoid mutating the shared topk value while post-processing one shape so later shape groups keep the intended candidate limit. Made-with: Cursor
* fix lint
* Cache FlyDSL shared memory queries

  Avoid repeated device property lookups while validating FlyDSL kernel configs by caching the default device selection and shared-memory-per-block queries. Made-with: Cursor
* fix lint
* Update aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py

  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* parse kernel name to select flydsl kernel
* fix black format error
* refine

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: solin <bingzhou@amd.com>
Co-authored-by: Sergey Solo <ssolovye@amd.com>
* Make AiterAsmKernel load hsaco on each GPU it is used on
* Replace unsafe uses of std::unordered_map with SynchronizedCache
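The pattern behind SynchronizedCache, sketched in Python (the real class is C++): guard every lookup/insert with a lock so concurrent callers cannot race on the underlying map:

```python
import threading

class SynchronizedCache:
    def __init__(self):
        self._lock = threading.Lock()
        self._map = {}

    def get_or_insert(self, key, factory):
        with self._lock:
            if key not in self._map:
                self._map[key] = factory()  # e.g. load the hsaco for this GPU
            return self._map[key]
```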
…2723)

* Fix Triton MoE GEMM shared memory exhaustion
  - Reduce num_stages in kernel configs
  - Lower LDS usage to avoid shared memory out-of-resource errors
  - Fix triton.runtime.errors.OutOfResources errors in MoE GEMM kernels
* Fix: set num_stages=1 on gfx950 (selected conditionally via get_arch()) so gfx942 is unaffected
* Add determinism for fused mul add test
* Format fused mul add test with black
The type annotation bool was incorrect for moe_sorting_dispatch_policy, which accepts int values. The @torch_compile_guard decorator uses these annotations to generate PyTorch custom op schemas; with bool, PyTorch schema enforcement casts any value to bool, so dispatch_policy=2 becomes bool(2)=True (1), silently losing the intended policy. Using int allows callers to set dispatch_policy=2 correctly.

Fixes: #2576

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Tres Popp <tres.popp@amd.com>
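The failure mode is plain Python semantics, independent of PyTorch:

```python
# Schema enforcement coerces through bool, which collapses all non-zero ints:
assert bool(2) is True
assert int(bool(2)) == 1   # dispatch_policy=2 silently arrives as 1
# With the annotation changed to int, the value passes through unchanged:
assert int(2) == 2
```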
* fix moe splitk aot and jit
* split moe aot into several libs based on tuned_moe configs
* update copyrights
* fix typo
* test shuffle as default and fix moe split jit
…ed RHS compatibility (#2704)

* Update quant.py
* Refactor per_1x32_f4_quant function signature
* Fix function definition formatting in quant.py
* Refactor per_1x32_f4_quant_for_dot_scaled definition
* Restore semantic.py to match main branch
…t) (#2729)

* Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)
* black mla.py
* fix short kv len split error
* Support final LSE output in non-persistent MLA reduce kernel
* black mla.py
* ruff error

---------

Co-authored-by: minmengdie <memin@amd.com>
---------

Co-authored-by: coderfeli <felix.li@amd.com>
* Update kimik2 FP4 tuned fMoE config with 256 CU tuning results

  Made-with: Cursor
* Fix kernel names for kimik2 FP4 token=8 expert=385 topk=9 config entry

---------

Co-authored-by: okorzh-amd <okorzh-amd@users.noreply.github.com>
* fix(pa): port sliding window mtp fixes to main

  Reapply the sliding-window MTP decode, PS reduce, and KV_BLOCK_SIZE=1024 fixes on top of the latest main so the change can be reviewed and merged from a branch with shared history. Made-with: Cursor
* fix fallback
* fix(pa): Use FlyDSL PS reduce for sliding-window MTP (Made-with: Cursor)
* fix(pa): Move Gluon decode imports to module top (Made-with: Cursor)
* fix regression

---------

Signed-off-by: fsx950223 <fsx950223@outlook.com>
… network issues (#2895)
Keep MI35X as the default Triton PR path, add ci:triton-300x to start extra MI300X jobs on PRs, and run both architectures on main. Update the PR welcome comment to document the new label and main-branch behavior.
* add optimized prefill gdn kernels for qwen3_5
* refine code style
* add ssm_state vk_layout kernels for vllm support
* add default turnoff for triton autotune
…#2262)

* Introduce asm fmoe kernels that do not require bf16->fp8 quantization
* Update quantization division precision to be closer to IEEE correctness
* Add transpose_scale flag
* Update kernels with fixed s_waitcnt
* Add 16x128 kernel merged with quantization
* Remove legacy code
* Remove one more redundant line
* Revert changes to sub_X_cnt calculations for fmoe asm
* Update the tuner to support x_bf16 kernel
* Update pandas interface
* Fix formatting
* Generate config with the new tuner
* Update base tuner to support merging csv files without strict column matching
* Fix formatting
* Fix core.py to handle columns merge
* Fix formatting

---------

Co-authored-by: Sergey Solo <ssolovye@amd.com>
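The column-lenient CSV merge can be done with pd.concat, which unions the columns across frames and fills gaps with NaN instead of failing on a strict column-set match; file names here are hypothetical:

```python
import pandas as pd

frames = [pd.read_csv(p) for p in ["tuned_fmoe.csv", "tuned_fmoe_x_bf16.csv"]]
merged = pd.concat(frames, ignore_index=True, sort=False)  # union of columns
merged.to_csv("tuned_fmoe_merged.csv", index=False)
```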
#2893)

* test: expand test_batch_prefill_large_kvcache for >4GB KV cache overflow

  Rewrite test_batch_prefill_large_kvcache to validate the per-tile SRD rebase fix for >4GB KV caches across all page sizes, dtypes, and attention configurations:
  - Add page_size=1 and 16 (page_size < kN0, exercises rebase path)
  - Add GQA (16, 8) in addition to MHA (8, 8)
  - Add causal masking with CK-compatible attn_mask for SDPA reference
  - Use full KV cache (4.5GB) with pages spanning the overflow boundary
  - Use torch SDPA as reference (memory-efficient backend, no score matrix materialization)
  - Add scatter_pages parameter (False only; True for future global_load_lds flat addressing)
  - Add GPU memory check to skip configs that exceed HBM capacity

  Test matrix: 24 cases (3 page_sizes × 2 dtypes × 2 causal × 2 GQA × 1 scatter)
* test: add GPU sync after CK kernel in large_kvcache test

  Add torch.cuda.synchronize() after the CK kernel launch in test_batch_prefill_large_kvcache to ensure all async GPU work completes before memory is freed between tests. Without this sync, repeated allocate/free cycles of large KV cache buffers (~20GB) with mixed dtype (bf16→fp8) can trigger GPU page faults when the HIP memory allocator reuses virtual addresses that are still referenced by pending async GPU work. The fault manifests as VM_L2_PROTECTION_FAULT at address 0x0 (NULL), causing GPU reset and kernel soft lockup.
* feat(fmha): runtime dispatch for >4GB KV cache in batch prefill

  Add use_64bit_load to batch prefill traits and runtime overflow detection. When page_block_size < 128 and max_page_byte_offset > INT32_MAX, dispatch to the flat 64-bit load kernel variant for correctness. Also add vectorized KV layout coverage to test_batch_prefill_large_kvcache.
* fix: remove unused k_vector_size variable in large_kvcache test
* fix(mha): improve batch_prefill TORCH_CHECK error message for >4GB KV cache

  Include page_size, num_pages, and dtype in the error message when kernel dispatch fails. Add a hint about the CDNA3+ GPU requirement when the KV cache exceeds 4GB with page_size < 128.
* test: update scatter_pages comment in large_kvcache test

  The comment incorrectly stated scatter_pages=True was "expected to FAIL". This is no longer true — the flat 64-bit load path handles scattered pages correctly. Update to describe the test's purpose instead.
* fix(mha): widen batch_prefill 64-bit threshold to total KV bytes

  The previous check used (num_total_pages - 1) * batch_stride * element_size, which measures the last-page base offset, missing within-page offsets and producing an off-by-one at exactly INT32_MAX (the largest representable SRD voffset). Switch to total KV cache footprint (num_total_pages * batch_stride * element_size > INT32_MAX) so within-page reads on the last page are covered, and drop the redundant num_total_pages > 1 guard since single-page configs trivially fit in 32 bits. Also unify wording: 4GB → 2GB (INT32_MAX byte offset for SRD voffset), matching CK's TwoGB convention. The actual hardware bound has always been 2GB; the prior comments were imprecise. Found during batch prefill template dispatch review.
* docs(mha): unify >2GB wording in batch_prefill error and test

  The 4GB number in the TORCH_CHECK error message and the test comment was imprecise — the actual SRD voffset bound is 2GB (INT32_MAX). Update both to match the threshold check and CK's TwoGB convention. Found during batch prefill template dispatch review.
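The widened predicate, restated as a Python sketch (this decision later moves into the CK dispatcher per the follow-up below):

```python
INT32_MAX = 2**31 - 1  # largest representable SRD voffset, i.e. the 2GB bound

def needs_global_load(num_total_pages, batch_stride, element_size):
    """True when the total KV footprint exceeds what a 32-bit voffset covers."""
    return num_total_pages * batch_stride * element_size > INT32_MAX
```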
* refactor(mha): drop wrapper-side use_64bit_load; let CK dispatcher decide

  The wrapper hardcoded kN0_min = 128 to compute the >2GB KV cache predicate, which leaked CK tile config into aiter and would silently break if a new arm with bn0 != 128 were added. The CK auto-generated dispatcher now decides per-arm using its own compile-time bn0 and per-dtype kElementBytes, so the wrapper just forwards args. Remove the `use_64bit_load` runtime field from `mha_batch_prefill_traits`, the parameter from `get_mha_batch_prefill_traits()`, and the entire predicate computation block from the dispatcher call site. Bumps CK submodule to pull in the matching codegen change.
* chore(mha): bump CK + update wrapper wording for kUseGlobalLoad rename

  Bumps 3rdparty/composable_kernel to dd8d293ea (refactor(fmha): batch prefill review polish — assert helper + setter guards), which builds on the prior 99a3ca9af kUseGlobalLoad rename. Wrapper-side updates to match:
  - csrc/cpp_itfs/mha_fwd_batch_prefill.cu: rename "64-bit-load" wording in the per-arm dispatcher comment to "kUseGlobalLoad" so the wrapper comment matches the CK-side identifier. Also drop the trailing `false /* skip_min_seqlen_q */` argument from the get_mha_batch_prefill_traits call to match the upstream CK API signature change.
  - csrc/py_itfs_ck/mha_batch_prefill_kernels.cu: change the >2GB error message from "page_size < 128" to "page_size < kN0" so the diagnostic tracks the tile-size constant rather than a magic number.
  - op_tests/test_batch_prefill.py (test_batch_prefill_large_kvcache): three documentation enhancements with no behavior change —
    explain why qo_len caps at 128 (causal) / 1024 (non-causal): the causal cap is a math-backend cliff for the SDPA reference, not a kernel limit;
    explain that the +256 padding on kv_page_indices is a batch_prefill ABI requirement (the kernel may speculatively read up to bn0=256 entries past the last valid page index);
    expand the torch.cuda.synchronize comment to call out the misattribution failure mode and GPU-reset cascade risk.
* test(fmha): parametrize test_batch_prefill_large_kvcache over batch_size {1, 4}

  Adds multi-batch coverage to the >2GB KV cache regression test. The previous single-batch coverage left the kernel's per-sequence SRD rebase path unexercised: with cu_seqlens_q=[0, qo_len] and kv_indptr=[0, num_blocks], the kernel never walks the indptr to reposition K/V SRDs across batch boundaries. After the kUseGlobalLoad rename and the new positive static_assert(kUseGlobalLoad_) calls in update_physical_pages and set_page_stride_elements, we want a regression that catches any boundary-crossing SRD bug -- the failure mode no single-batch test can detect (one batch correct, others wrong). batch_size=4 partitions the >2GB page pool across 4 sequences (the last sequence absorbs the remainder), exercising 3 cross-batch SRD transitions. The SDPA reference is computed per-batch and concatenated; per-iteration free + empty_cache keeps peak memory at one batch's worth.

  Verified on:
  - gfx950 (smci355-gfx950, MI355X): 160 passed, 32 skipped
  - gfx942 (smc300x-clt, MI308X): 160 passed, 32 skipped

  Skips are the existing vectorized + page_size=1 incompatibility (3D tensor layout), now 16 per batch_size value.

---------

Co-authored-by: Xin Huang <Xin.Huang@amd.com>
---------

Co-authored-by: zhuyuhua-v <yuhzhu@amd.com>
* mHC: Optimize mhc_pre performance for small M (add tileN=16 to mhc_pre_gemm_sqrsum and splitk to mhc_pre_big_fuse)
* update test
* fix(mha_bwd): pass independent strides for do in _bwd_preprocess
  - add independent `stride_do_b/h/m/k` parameters to `_bwd_preprocess` and address dO with them;
  - pass `*do_strides` (already computed in the wrapper) alongside `*o_strides` when launching the preprocess kernel.
* pass independent strides for flash_attn_fused_backward kernel
* fix typo
* [Triton]: Add MoE a16w4
* Fix Black issue
* Removed x scales and improved arch checks in unit test
* Added benchmark
* Fixed PR review issues
* PR review comments fixes

---------

Co-authored-by: Rahul Batra <rahbatra@amd.com>
#2825)

---------

Signed-off-by: zhimding <zhiming.ding@amd.com>
Signed-off-by: zhimding <zhimding@amd.com>
* CI: log in before starting OPUS test container

  Authenticate with the Docker registry before `docker run` so OPUS test jobs can pull the image reliably across runners.
* CI: add docker login before ATOM base image pull

  Authenticate with the Docker registry before pulling the ATOM base image so the workflow can fetch the container reliably across runners.
* fea(car): support custom group device
* [fix]: test script code format
* fea(car): support multi communication groups entity
* fea(ag): support multi group allgather
* fix: test script format
* fea(car): add reduce_scatter interface
* fix(car): custom group config
* fix: custom comm group set method

---------

Signed-off-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
Signed-off-by: root <root@smci355-ccs-aus-n01-17.cs-aus.dcgpu>
Co-authored-by: TennyWang1223 <root@hjbog-srdc-24.amd.com>
Co-authored-by: root <root@smci355-ccs-aus-n01-17.cs-aus.dcgpu>
* [CK] Add StreamLLM sink token support for batch_prefill pipeline

  Update CK submodule to ROCm/rocm-libraries#6479 (commit feea6be), which makes kHasSink_ a real template parameter in fmha_fwd_batch_prefill_traits_, enabling _sink/_nsink kernel variant codegen for batch_prefill. AITER-side changes thread sink_size through the full call chain:

  Python (aiter/ops/mha.py):
  - Add sink_size: int = 0 to mha_batch_prefill_func, _mha_batch_prefill, the mha_batch_prefill stub, and cmdGenFunc_mha_batch_prefill
  - Use has_effective_sink = sink_size > 0 and (causal or has_window_mask) for _sink/_nsink module name selection, matching C++ mask logic

  C++ interface (csrc/):
  - mha_batch_prefill_traits: expose has_sink parameter (was hardcoded false)
  - mha_fwd_batch_prefill.cu: derive has_sink = args.sink_size > 0
  - mha_batch_prefill_kernels.cu: add int sink_size param, include it in the mask_identify string, set args.sink_size = mask.sink, zero-initialize fmha_batch_prefill_args{} to avoid UB
  - PyBind (rocm_ops.hpp) and declaration (torch/mha_batch_prefill.h): add sink_size positional parameter after window_size_right

  Semantics clarified:
  - sink_size: number of first KV tokens always attended (sink phase loop)
  - sink_ptr[nhead]: fixed logit for a virtual sink token in softmax (null -> -inf, non-null -> user value); independent of sink_size
  - window_size_left=L means k in [abs_q-L, abs_q] (L+1 tokens), verified via code derivation (block_masking.hpp) and discriminating tests

  Verified:
  - No-sink paths numerically correct vs torch reference (max_diff<0.004)
  - sink_ptr virtual token semantics correct (max_diff<0.004)
  - window=1024 + sink_size=4 + sink_ptr: max_diff=0.000488
* [test] Add StreamLLM sink token tests for batch_prefill

  Add test_batch_prefill_sink pytest function and supporting helpers to op_tests/test_batch_prefill.py:
  - ref_masked_attention_with_sink: torch reference implementing sink semantics (first sink_size KV positions always valid, sink_ptr virtual token appended to the softmax attention matrix)
  - run_batch_prefill_sink: runs both reference and CK kernel, compares with get_tolerances() thresholds
  - test_batch_prefill_sink: parametrized over:
    batch_size=[1,2], page_size=16, head_dim=128,
    num_qo_heads/num_kv_heads=[(8,1),(4,4)] (GQA + MHA),
    qo_len=[32,128],
    (window_left, kv_len)=[(128,512),(1024,2048)] <- real gap in both,
    sink_size=[4,16],
    sink_ptr_value=[None, 0.0, 2.0],
    dtype=[bfloat16, float16]

  Verified manually with batch=1 GQA 8/1 window=128 sink=4 ptr=2.0; batch=2 MHA 4/4 window=1024 sink=16 ptr=None; and sink_ptr=0.0.
* [style] Black format op_tests/test_batch_prefill.py
* [fix] Address Copilot review on batch_prefill sink PR
  - aiter/ops/mha.py: add the missing sink_size parameter to mha_batch_prefill_fake_tensors so its signature matches the real mha_batch_prefill op; without this, torch.compile / fake-tensor mode raises TypeError when invoking the op.
  - op_tests/test_batch_prefill.py: vectorize the StreamLLM reference mask construction. The previous double Python loop did per-element assignment into a CUDA tensor, triggering O(seqlen_q * seqlen_k) GPU sync points (262144 at seqlen_k=2048) and would time out the parametrized tests. Replace with broadcasted index tensors and a single masked_fill_.
* [test] Wire StreamLLM sink scenarios into CI __main__ block

  CI invokes test files via `python3 "$file"`, which only runs the __main__ block; pytest functions like test_batch_prefill_sink were therefore never exercised in CI.
  Add a second __main__ block at the end of the file that runs run_batch_prefill_sink over a small representative parameter sweep (2 window/kv combos x 2 sink_ptr settings, sink_size=4, bf16). The helpers are defined after the original __main__, so a separate trailing __main__ block is the minimal-blast-radius way to keep both invocations independent.
* [CK] Bump submodule to include StreamLLM sink for batch_prefill

  Update CK submodule from 08792e0b3 to d22aafb48 to pull in "[CK][fmha] Add StreamLLM sink support to batch_prefill pipeline" (#6479).
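The vectorized reference mask described above, as a hedged sketch (variable names illustrative, not the test helpers): broadcasted index tensors and a single masked_fill_ replace the per-element double loop:

```python
import torch

def sink_window_mask(seqlen_q, seqlen_k, window_left, sink_size, device="cuda"):
    q = torch.arange(seqlen_q, device=device).unsqueeze(1)  # [Sq, 1]
    k = torch.arange(seqlen_k, device=device).unsqueeze(0)  # [1, Sk]
    abs_q = q + (seqlen_k - seqlen_q)                       # align q to the kv tail
    valid = (k >= abs_q - window_left) & (k <= abs_q)       # window of L+1 tokens
    valid |= k < sink_size                                  # sink tokens always attend
    bias = torch.zeros(seqlen_q, seqlen_k, device=device)
    # One masked_fill_ launch, no per-element GPU sync points.
    return bias.masked_fill_(~valid, float("-inf"))
```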
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist