From 1b7a9f91d348d05ca6fee28b1d23c81a12a46580 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Wed, 29 Apr 2026 11:20:28 -0700 Subject: [PATCH 01/14] DSV4 ATOM optims --- .../single_node/dsv4_fp4_mi355x_atom.sh | 98 ++----------------- perf-changelog.yaml | 8 ++ 2 files changed, 18 insertions(+), 88 deletions(-) diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index 21708ba1d..cb802f4d1 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -58,7 +58,8 @@ fi export AITER_LOG_LEVEL=WARNING # Pull in the AITER pieces that matter for DSv4 FP4 on MI355X: -# * origin/main@dde1703e includes ROCm/aiter#2770 a16w4 MoE support. +# * origin/main@8c27e66f includes ROCm/aiter#2770 a16w4 MoE support and +# ROCm/aiter#2916 mhc_pre device-allocation fix. # * ROCm/aiter#2822 speeds up batched MXFP4 GEMM on gfx950. # * ROCm/aiter#2900 fixes MXFP4 scale padding for non-256 K. # * ROCm/aiter#2642 enables/fixes TP=4/8 MXFP4 MoE dispatch. @@ -66,16 +67,14 @@ export AITER_LOG_LEVEL=WARNING # eligible token counts to FlyDSL FP4 MoE kernels instead of default CK # heuristics when the image has the optional flydsl package. # -# ROCm/aiter#2916 is intentionally not cherry-picked here. That PR branch is -# based on a divergent fork and can conflict in unrelated test files; the -# narrow mhc_pre device fix is applied directly to installed aiter below. -# The non-mHC PRs cherry-pick cleanly over the pinned main SHA as of 2026-04-27. +# The open performance PRs cherry-pick cleanly over the pinned main SHA as +# of 2026-04-29. # Keep this as a runtime overlay until AMD publishes an ATOM image with these # AITER changes baked in; then remove this block and pin that image instead. if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then AITER_PERF_REPO=${AITER_PERF_REPO:-https://github.com/ROCm/aiter.git} AITER_PERF_DIR=${AITER_PERF_DIR:-/tmp/aiter-dsv4-fp4-perf} - AITER_PERF_BASE_SHA=${AITER_PERF_BASE_SHA:-dde1703ebfc35d3724e07fc4e6e824023063494c} + AITER_PERF_BASE_SHA=${AITER_PERF_BASE_SHA:-8c27e66f8078c8e1e9ac4f55a5481e2a37db96f0} AITER_PERF_PATCH_REFS=( "${AITER_PERF_BATCHED_FP4_REF:-pull/2822/head}" "${AITER_PERF_MXFP4_SCALE_REF:-pull/2900/head}" @@ -128,7 +127,6 @@ if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then fi python3 - <<'PYEOF' -import importlib.util import csv import os from pathlib import Path @@ -136,10 +134,15 @@ import aiter root = Path(aiter.__file__).resolve().parent moe = (root / "fused_moe.py").read_text() +mhc = (root / "ops" / "mhc.py").read_text() fp4_utils = (root / "utility" / "fp4_utils.py").read_text() dsv4_tuned_fmoe = Path(os.environ["AITER_DSV4_TUNED_FMOE_FILE"]) if os.environ.get("AITER_DSV4_TUNED_FMOE_FILE") else None required = { "native MXFP4 MoE skip_inter_quant": "skip_inter_quant" in moe, + "mhc_pre device allocation fix": ( + "device = residual.device" in mhc + and "dtype=dtypes.bf16, device=device" in mhc + ), "MXFP4 scaleN_pad fix": "scaleN_pad" in fp4_utils, "DSv4 FP4 tuned fMoE config": dsv4_tuned_fmoe is None or dsv4_tuned_fmoe.exists(), } @@ -181,87 +184,6 @@ else echo "WARN: AITER_DSV4_PERF_STACK=0; using image-provided aiter" fi -# Ensure the pure-Python part of ROCm/aiter#2916 is present. The AITER perf -# stack above already includes it; this block is kept as a fallback for -# AITER_DSV4_PERF_STACK=0 or future images that ship aiter without the fix. 
-export AITER_MHC_FIX_SHA="76ea1ed5b2a5f8176ed7a16b1640dd972546a925" -python3 - <<'PYEOF' -import importlib.util -import os -import sys -from pathlib import Path - -required_snippets = [ - " device = residual.device\n out_pad = torch.empty(", - "selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device", - "sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)", - "post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)", - "comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)", - "layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)", -] - -spec = importlib.util.find_spec("aiter.ops.mhc") -if spec is None or spec.origin is None: - sys.exit("FATAL: cannot locate installed aiter.ops.mhc for ROCm/aiter#2916 patch") - -mhc_path = Path(spec.origin) -source = mhc_path.read_text() - -if all(snippet in source for snippet in required_snippets): - print(f"aiter mhc device patch already present: {mhc_path}") - sys.exit(0) - -replacements = [ - ( - " out_pad = torch.empty(\n" - " selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32\n" - " )", - " device = residual.device\n" - " out_pad = torch.empty(\n" - " selected_splitk, m, (hc_mult3 + 31) // 32 * 32, dtype=dtypes.fp32, device=device\n" - " )", - ), - ( - " sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32)", - " sqrsum = torch.empty(selected_splitk, m, dtype=dtypes.fp32, device=device)", - ), - ( - " post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32)", - " post_mix = torch.empty(m, hc_mult, 1, dtype=dtypes.fp32, device=device)", - ), - ( - " comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32)", - " comb_mix = torch.empty(m, hc_mult, hc_mult, dtype=dtypes.fp32, device=device)", - ), - ( - " layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16)", - " layer_input = torch.empty(m, hidden_size, dtype=dtypes.bf16, device=device)", - ), -] - -missing = [old for old, _ in replacements if old not in source] -if missing: - sys.exit( - f"FATAL: {mhc_path} does not match the expected pre-ROCm/aiter#2916 " - f"source; refusing to patch mhc_pre blindly. Missing patterns: " - f"{[m.splitlines()[0].strip() for m in missing]}" - ) - -patched = source -for old, new in replacements: - patched = patched.replace(old, new, 1) - -mhc_path.write_text(patched) -patched_source = mhc_path.read_text() -if not all(snippet in patched_source for snippet in required_snippets): - sys.exit(f"FATAL: ROCm/aiter#2916 mhc device patch failed verification for {mhc_path}") - -print( - f"applied ROCm/aiter#2916 ({os.environ['AITER_MHC_FIX_SHA']}) " - f"mhc device patch: {mhc_path}" -) -PYEOF - # Apply ROCm/ATOM#650 (DSv4 PR1 skeleton) over the image's wheel-installed # atom. The chosen base image ships atom as a built wheel, not editable, so # we overlay an editable install from the PR branch at a pinned SHA. 
Bump diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 2611cd000..64d43feb1 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -1999,3 +1999,11 @@ - "Add conc=8192 recipe for 1k1k: deepep mega_moe backend with cuda-graph-max-bs 1088, max-running-requests 8192, mem-fraction-static 0.80, swa-full-tokens-ratio 0.3, tokenizer-worker-num 16" - "conc=8192 enables SGLANG_OPT_USE_ONLINE_COMPRESS=1 and --stream-interval 30" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1209 + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Advance the AITER DSv4 perf-stack base to ROCm/aiter main commit 8c27e66f, which includes merged ROCm/aiter#2916" + - "Remove the local runtime aiter/ops/mhc.py patcher and AITER_MHC_FIX_SHA fallback now the mhc_pre device-allocation fix is upstream" + - "Keep the open AITER performance PR cherry-picks (#2642, #2822, #2900); they are not merged upstream yet" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX From 25cc815b9aabc0ffa5416e08bb28e81710358ccb Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Wed, 29 Apr 2026 15:20:45 -0700 Subject: [PATCH 02/14] flydsl --- benchmarks/single_node/dsv4_fp4_mi355x_atom.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index cb802f4d1..d3526bab1 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -58,8 +58,9 @@ fi export AITER_LOG_LEVEL=WARNING # Pull in the AITER pieces that matter for DSv4 FP4 on MI355X: -# * origin/main@8c27e66f includes ROCm/aiter#2770 a16w4 MoE support and -# ROCm/aiter#2916 mhc_pre device-allocation fix. +# * origin/main@bb4ea92e includes ROCm/aiter#2770 a16w4 MoE support, +# ROCm/aiter#2916 mhc_pre device-allocation fix, and ROCm/aiter#2924 +# FlyDSL GDR decode tuned configs. # * ROCm/aiter#2822 speeds up batched MXFP4 GEMM on gfx950. # * ROCm/aiter#2900 fixes MXFP4 scale padding for non-256 K. # * ROCm/aiter#2642 enables/fixes TP=4/8 MXFP4 MoE dispatch. 
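Note on the runtime verification these comments refer to: the overlay is trusted only after a source-level probe of the installed aiter tree. Patch 01's heredoc builds a `required` dict of feature checks (extended below with the FlyDSL tuned-config probe); its handling of a failed probe sits outside the diff context, so the following is a minimal sketch of the probe-and-fail pattern, with the exit behavior assumed rather than taken from the script.

# Minimal sketch of the probe-and-fail pattern used by the verification
# heredoc: probe the installed aiter source for each cherry-picked change.
# The probes mirror the `required` dict in patch 01; the exit behavior is
# an assumption, since the heredoc's tail is outside the diff context.
import sys
from pathlib import Path

import aiter

root = Path(aiter.__file__).resolve().parent
moe = (root / "fused_moe.py").read_text()
fp4_utils = (root / "utility" / "fp4_utils.py").read_text()

required = {
    "native MXFP4 MoE skip_inter_quant": "skip_inter_quant" in moe,
    "MXFP4 scaleN_pad fix": "scaleN_pad" in fp4_utils,
}

missing = sorted(name for name, ok in required.items() if not ok)
if missing:
    # Abort before serving: benchmarking an aiter build that silently
    # lacks the cherry-picks would produce misleading numbers.
    sys.exit(f"FATAL: aiter perf stack verification failed: {missing}")
print("aiter perf stack verified:", ", ".join(sorted(required)))

Probing source text rather than version numbers is what lets the same script validate both the cherry-picked overlay and any future image that ships the fixes natively.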
@@ -74,7 +75,7 @@ export AITER_LOG_LEVEL=WARNING if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then AITER_PERF_REPO=${AITER_PERF_REPO:-https://github.com/ROCm/aiter.git} AITER_PERF_DIR=${AITER_PERF_DIR:-/tmp/aiter-dsv4-fp4-perf} - AITER_PERF_BASE_SHA=${AITER_PERF_BASE_SHA:-8c27e66f8078c8e1e9ac4f55a5481e2a37db96f0} + AITER_PERF_BASE_SHA=${AITER_PERF_BASE_SHA:-bb4ea92eaf7a8420ab6bcc460095d310d02dd628} AITER_PERF_PATCH_REFS=( "${AITER_PERF_BATCHED_FP4_REF:-pull/2822/head}" "${AITER_PERF_MXFP4_SCALE_REF:-pull/2900/head}" @@ -143,6 +144,9 @@ required = { "device = residual.device" in mhc and "dtype=dtypes.bf16, device=device" in mhc ), + "FlyDSL GDR decode tuned configs": ( + root / "ops" / "flydsl" / "gdr_decode_tuned.jsonl" + ).exists(), "MXFP4 scaleN_pad fix": "scaleN_pad" in fp4_utils, "DSv4 FP4 tuned fMoE config": dsv4_tuned_fmoe is None or dsv4_tuned_fmoe.exists(), } From 07f02d2019c6cfba52e39250feea83f77c272226 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Wed, 29 Apr 2026 17:03:33 -0700 Subject: [PATCH 03/14] higher conc --- .github/configs/amd-master.yaml | 17 +- .../single_node/dsv4_fp4_mi355x_atom.sh | 473 +++++++++++++++++- perf-changelog.yaml | 10 +- 3 files changed, 475 insertions(+), 25 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 1faf3682b..2045a110d 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1538,13 +1538,12 @@ dsv4-fp8-mi355x-vllm: search-space: - { tp: 8, conc-start: 1, conc-end: 1 } -# Day-0 single-sequence marker for DeepSeek-V4 on ATOM (ROCm/ATOM#650). -# PR1 of the ATOM DSv4 series still uses torch sparse-attention fallbacks -# that OOM once warmup/prefill batches multiple requests; keep CONC=1 until -# the AITER sparse-attention kernel / multi-request path lands upstream. -# --enforce-eager and ATOM_USE_TRITON_MOE=1 are required on gfx950. Image is -# the standard atom0.1.2.post MI355X base (matching qwen3.5-fp8-mi355x-atom); -# the DSv4 PR is overlaid at runtime by dsv4_fp4_mi355x_atom.sh at a pinned SHA. +# Day-0 DeepSeek-V4 on ATOM (ROCm/ATOM#650) with local runtime overlays. +# dsv4_fp4_mi355x_atom.sh patches PR650 to give each request persistent DSv4 +# KV/compressor/indexer cache slots, unblocking CONC>1 smoke coverage. The path +# still uses eager execution and request-looped DSv4 attention until upstream +# lands native multi-request sparse-attention/cache support, so keep this sweep +# conservative. dsv4-fp4-mi355x-atom: image: rocm/atom:rocm7.2.2_ubuntu24.04_py3.12_pytorch_release_2.10.0_atom0.1.2.post model: deepseek-ai/DeepSeek-V4-Pro @@ -1557,8 +1556,8 @@ dsv4-fp4-mi355x-atom: - isl: 1024 osl: 1024 search-space: - - { tp: 8, ep: 1, conc-start: 1, conc-end: 1 } + - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } - isl: 8192 osl: 1024 search-space: - - { tp: 8, ep: 1, conc-start: 1, conc-end: 1 } + - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index d3526bab1..ca5afbec1 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -19,17 +19,11 @@ fi echo "TP: $TP, CONC: $CONC, ISL: $ISL, OSL: $OSL, EP_SIZE: $EP_SIZE" -# ROCm/ATOM#650 is still a single-request marker for DSv4. Run -# 24953107645 showed CONC>1 fails in two ways: 1k warmup can exhaust the KV -# budget after sparse-attn temporaries raise peak memory, and 8k prefill OOMs -# in the torch sparse_attn fallback when two long requests are batched. 
Keep -# this fatal guard until ATOM lands the AITER sparse-attention / multi-request -# path for DeepSeek-V4. -if [ "$CONC" -ne 1 ]; then - echo "FATAL: ROCm/ATOM#650 DSv4 path is single-request only; CONC must be 1, got $CONC" >&2 - exit 1 -fi - +# ROCm/ATOM#650 is still a PR1 DSv4 skeleton. The local overlay below gives +# DSv4 persistent per-request cache slots so CONC>1 no longer corrupts the +# recurrent KV/compressor/indexer state. It is still eager and not vectorized +# across requests, so keep sweep points modest until upstream lands the native +# multi-request sparse-attention/cache path. if [ "$EP_SIZE" -ne 1 ]; then echo "FATAL: ROCm/ATOM#650 PR1 has not validated expert parallel serving; EP_SIZE must be 1, got $EP_SIZE" >&2 exit 1 @@ -193,7 +187,7 @@ fi # we overlay an editable install from the PR branch at a pinned SHA. Bump # this SHA when the PR moves; do not track the branch tip (the run becomes # a moving target if the branch is force-pushed). -ATOM_PR_SHA="cdbff359d3db7afd3801e28b38fc71253121ee84" +ATOM_PR_SHA="af17eb89ceb6370b0c1724aef3bf938e6baedecd" export ATOM_PR_DIR="/tmp/atom-pr650" if [ ! -d "$ATOM_PR_DIR/.git" ]; then @@ -286,6 +280,455 @@ else: print(f"DSv4 sparse_attn_v4 decode/chunk patch already present: {path}") PYEOF + # Local multi-request overlay for ROCm/ATOM#650. ATOM's scheduler passes + # DSv4 a token-flat batch, but PR650 treats every request as cache slot 0 + # (`kv_cache[:1]` and matching compressor/indexer state). Reuse ATOM's + # mamba-state slot allocator for DSv4, then run each sequence against its + # persistent slot. This fixes correctness for CONC>1; it is intentionally + # conservative and still loops requests until upstream vectorizes the DSv4 + # sparse-attention/cache path. + sed 's/^$/ /' <<'PATCH' | git apply +diff --git a/atom/model_engine/llm_engine.py b/atom/model_engine/llm_engine.py +index 8de9532..ddde446 100644 +--- a/atom/model_engine/llm_engine.py ++++ b/atom/model_engine/llm_engine.py +@@ -171,7 +171,16 @@ class InputOutputProcessor: + self.num_speculative_tokens = ( + self.config.speculative_config.num_speculative_tokens + ) +- mamba_model_types = {"qwen3_next", "qwen3_5_text", "qwen3_5_moe_text"} ++ mamba_model_types = { ++ "qwen3_next", ++ "qwen3_5_text", ++ "qwen3_5_moe_text", ++ "deepseek_v4", ++ "deepseek_v4_pro", ++ } +- if self.config.hf_config.model_type in mamba_model_types: ++ architectures = getattr(self.config.hf_config, "architectures", []) or [] ++ if self.config.hf_config.model_type in mamba_model_types or any( ++ "DeepseekV4" in arch for arch in architectures ++ ): + self.mamba_enabled = True + +diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py +index 72e9d84..598f2a5 100644 +--- a/atom/model_engine/model_runner.py ++++ b/atom/model_engine/model_runner.py +@@ -659,6 +659,14 @@ class ModelRunner: + ) + return False + ++ def is_deepseek_v4(self) -> bool: ++ model_type = getattr(self.hf_text_config, "model_type", None) ++ architectures = getattr(self.hf_text_config, "architectures", []) or [] ++ return model_type in ( ++ "deepseek_v4", ++ "deepseek_v4_pro", ++ ) or any("DeepseekV4" in arch for arch in architectures) ++ + def is_qwen_next(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False +@@ -1250,9 +1256,10 @@ class ModelRunner: + + # GDN recurrent state: deduct mamba tensor memory from pool budget + mamba_per_slot = self._compute_mamba_per_slot_bytes() ++ needs_recurrent_slots = mamba_per_slot > 0 or self.is_deepseek_v4() + slots_per_req = 1 + 
self.num_spec_tokens + max_mamba_slots = ( +- config.max_num_seqs * slots_per_req if mamba_per_slot > 0 else 0 ++ config.max_num_seqs * slots_per_req if needs_recurrent_slots else 0 + ) + mamba_tensor_bytes = max_mamba_slots * mamba_per_slot + available_for_pool = available_for_kv - mamba_tensor_bytes +@@ -1270,7 +1277,7 @@ class ModelRunner: + # Store for BlockManager and allocate_kv_cache + config.mamba_equiv_per_req = mamba_equiv + config.max_mamba_slots = max_mamba_slots +- config.num_mamba_groups = config.max_num_seqs if mamba_per_slot > 0 else 0 ++ config.num_mamba_groups = config.max_num_seqs if needs_recurrent_slots else 0 + self.max_mamba_slots = max_mamba_slots + + num_kvcache_blocks = available_for_pool // block_bytes +@@ -1309,7 +1316,7 @@ class ModelRunner: + return { + "num_kvcache_blocks": num_kvcache_blocks, + "mamba_equiv_per_req": mamba_equiv, +- "num_mamba_groups": config.max_num_seqs if mamba_per_slot > 0 else 0, ++ "num_mamba_groups": config.max_num_seqs if needs_recurrent_slots else 0, + } + + def allocate_kv_cache(self, num_kvcache_blocks): +@@ -1782,6 +1789,13 @@ class ModelRunner: + ) + attn_metadata, positions = self.attn_metadata_builder.build(batch=batch, bs=bs) + context_bs = batch.total_seqs_num_prefill if is_prefill else scheduled_bs ++ if self.is_deepseek_v4(): ++ cache_slots = list(batch.mamba_state_slots) ++ if len(cache_slots) < context_bs: ++ cache_slots = list(range(context_bs)) ++ attn_metadata.dsv4_cache_slots = torch.tensor( ++ cache_slots[:context_bs], dtype=torch.int64, device=self.device ++ ) + + # graph_bs should be batch size (number of sequences), not token count + graph_bs = num_input_tokens if is_prefill else bs +diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py +index 46cf1b0..0d84c78 100644 +--- a/atom/models/deepseek_v4.py ++++ b/atom/models/deepseek_v4.py +@@ -506,7 +506,9 @@ class Compressor(nn.Module): + new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d] + return new_tensor + +- def forward(self, x: torch.Tensor, start_pos: int) -> Optional[torch.Tensor]: ++ def forward( ++ self, x: torch.Tensor, start_pos: int, cache_slot: int = 0 ++ ) -> Optional[torch.Tensor]: + """Compress KV for the input tokens. Writes into self.kv_cache when a + compression block boundary is hit; otherwise just buffers state and returns None. + +@@ -524,6 +526,7 @@ class Compressor(nn.Module): + if x.dim() == 2: + x = x.unsqueeze(0) # [num_tokens, dim] → [1, num_tokens, dim] + bsz, seqlen, _ = x.size() ++ slot = slice(cache_slot, cache_slot + bsz) + ratio = self.compress_ratio + overlap = self.overlap + d = self.head_dim +@@ -545,16 +548,16 @@ class Compressor(nn.Module): + # Save the last `ratio` overlap-slice tokens into kv_state for use + # by the next decode call's overlap window. + if overlap and cutoff >= ratio: +- self.kv_state[:bsz, :ratio] = kv[:, cutoff - ratio : cutoff] +- self.score_state[:bsz, :ratio] = ( ++ self.kv_state[slot, :ratio] = kv[:, cutoff - ratio : cutoff] ++ self.score_state[slot, :ratio] = ( + score[:, cutoff - ratio : cutoff] + self.ape + ) + # Save the trailing partial block (remainder tokens) into kv_state. 
+ if remainder > 0: +- kv, self.kv_state[:bsz, offset : offset + remainder] = kv.split( ++ kv, self.kv_state[slot, offset : offset + remainder] = kv.split( + [cutoff, remainder], dim=1 + ) +- self.score_state[:bsz, offset : offset + remainder] = ( ++ self.score_state[slot, offset : offset + remainder] = ( + score[:, cutoff:] + self.ape[:remainder] + ) + score = score[:, :cutoff] +@@ -570,20 +573,20 @@ class Compressor(nn.Module): + should_compress = (start_pos + 1) % self.compress_ratio == 0 + score = score + self.ape[start_pos % ratio] + if overlap: +- self.kv_state[:bsz, ratio + start_pos % ratio] = kv.squeeze(1) +- self.score_state[:bsz, ratio + start_pos % ratio] = score.squeeze(1) ++ self.kv_state[slot, ratio + start_pos % ratio] = kv.squeeze(1) ++ self.score_state[slot, ratio + start_pos % ratio] = score.squeeze(1) + if should_compress: + kv_state = torch.cat( + [ +- self.kv_state[:bsz, :ratio, :d], +- self.kv_state[:bsz, ratio:, d:], ++ self.kv_state[slot, :ratio, :d], ++ self.kv_state[slot, ratio:, d:], + ], + dim=1, + ) + score_state = torch.cat( + [ +- self.score_state[:bsz, :ratio, :d], +- self.score_state[:bsz, ratio:, d:], ++ self.score_state[slot, :ratio, :d], ++ self.score_state[slot, ratio:, d:], + ], + dim=1, + ) +@@ -591,14 +594,14 @@ class Compressor(nn.Module): + dim=1, keepdim=True + ) + # Roll: the just-completed window becomes the next overlap window. +- self.kv_state[:bsz, :ratio] = self.kv_state[:bsz, ratio:] +- self.score_state[:bsz, :ratio] = self.score_state[:bsz, ratio:] ++ self.kv_state[slot, :ratio] = self.kv_state[slot, ratio:] ++ self.score_state[slot, :ratio] = self.score_state[slot, ratio:] + else: +- self.kv_state[:bsz, start_pos % ratio] = kv.squeeze(1) +- self.score_state[:bsz, start_pos % ratio] = score.squeeze(1) ++ self.kv_state[slot, start_pos % ratio] = kv.squeeze(1) ++ self.score_state[slot, start_pos % ratio] = score.squeeze(1) + if should_compress: + kv = ( +- self.kv_state[:bsz] * self.score_state[:bsz].softmax(dim=1) ++ self.kv_state[slot] * self.score_state[slot].softmax(dim=1) + ).sum(dim=1, keepdim=True) + + if not should_compress: +@@ -622,9 +625,9 @@ class Compressor(nn.Module): + act_quant_inplace(kv[..., :-rd], 64, self.scale_fmt) + + if start_pos == 0: +- self.kv_cache[:bsz, : seqlen // ratio] = kv ++ self.kv_cache[slot, : seqlen // ratio] = kv + else: +- self.kv_cache[:bsz, start_pos // ratio] = kv.squeeze(1) ++ self.kv_cache[slot, start_pos // ratio] = kv.squeeze(1) + return kv + + +@@ -696,6 +699,7 @@ class Indexer(nn.Module): + qr: torch.Tensor, + start_pos: int, + offset: int, ++ cache_slot: int = 0, + ) -> torch.Tensor: + """Compute sparse top-k indices over the indexer's compressed KV cache. + +@@ -715,6 +719,7 @@ class Indexer(nn.Module): + ratio = self.compress_ratio + rd = self.rope_head_dim + end_pos = start_pos + seqlen ++ slot = slice(cache_slot, cache_slot + 1) + + # Lazy plumb the indexer's kv_cache + freqs_cis into its compressor. + if self.compressor.kv_cache is None: +@@ -729,7 +734,7 @@ class Indexer(nn.Module): + fp4_act_quant_inplace(q, _FP4_BLOCK_SIZE) + + # ----- Indexer KV (Compressor takes 2D, mutates kv_cache) ----- +- self.compressor(x, start_pos) ++ self.compressor(x, start_pos, cache_slot) + # weights_proj is ATOM Linear → 2D input; restore B=1 dim for einsum. 
+ weights = ( + self.weights_proj(x) * (self.softmax_scale * self.n_heads**-0.5) +@@ -737,7 +742,7 @@ class Indexer(nn.Module): + + # ----- Index score ----- + index_score = torch.einsum( +- "bshd,btd->bsht", q, self.kv_cache[:1, : end_pos // ratio] ++ "bshd,btd->bsht", q, self.kv_cache[slot, : end_pos // ratio] + ) + index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2) + +@@ -959,7 +964,9 @@ class DeepseekV4Attention(nn.Module): + + self.wo_a.quant_type = _QT.No + +- def forward(self, x: torch.Tensor, start_pos: int) -> torch.Tensor: ++ def forward( ++ self, x: torch.Tensor, start_pos: int, cache_slot: int = 0 ++ ) -> torch.Tensor: + """Compute attention for `x` at absolute position `start_pos`. + + Args: +@@ -978,6 +985,7 @@ class DeepseekV4Attention(nn.Module): + win = self.window_size + ratio = self.compress_ratio + rd = self.rope_head_dim ++ slot = slice(cache_slot, cache_slot + 1) + + # First-call plumbing: hand the (compressed-half) KV cache + freqs_cis + # to the compressor / indexer. +@@ -992,14 +1000,14 @@ class DeepseekV4Attention(nn.Module): + # with garbage. Real prefill only overwrites a few slots, leaving + # stale warmup data that poisons decode attention. + if start_pos == 0: +- self.kv_cache.zero_() ++ self.kv_cache[slot].zero_() + if self.compress_ratio: +- self.compressor.kv_state.zero_() +- self.compressor.score_state.fill_(float("-inf")) ++ self.compressor.kv_state[slot].zero_() ++ self.compressor.score_state[slot].fill_(float("-inf")) + if self.indexer is not None: +- self.indexer.kv_cache.zero_() +- self.indexer.compressor.kv_state.zero_() +- self.indexer.compressor.score_state.fill_(float("-inf")) ++ self.indexer.kv_cache[slot].zero_() ++ self.indexer.compressor.kv_state[slot].zero_() ++ self.indexer.compressor.score_state[slot].fill_(float("-inf")) + + # ----- Q: low-rank projection + per-head RMSNorm + partial RoPE ----- + # ATOM TP linears require 2D inputs; subsequent ops (RoPE, sparse_attn) +@@ -1023,7 +1031,7 @@ class DeepseekV4Attention(nn.Module): + if self.compress_ratio: + offset = kv.size(1) if start_pos == 0 else win + if self.indexer is not None: +- compress_topk_idxs = self.indexer(x, qr, start_pos, offset) ++ compress_topk_idxs = self.indexer(x, qr, start_pos, offset, cache_slot) + else: + compress_topk_idxs = _get_compress_topk_idxs( + ratio, 1, seqlen, start_pos, offset, device=x.device +@@ -1037,26 +1045,26 @@ class DeepseekV4Attention(nn.Module): + # implicit B=1.) 
----- + if start_pos == 0: + if seqlen <= win: +- self.kv_cache[:1, :seqlen] = kv ++ self.kv_cache[slot, :seqlen] = kv + else: + cutoff = seqlen % win + ( +- self.kv_cache[:1, cutoff:win], +- self.kv_cache[:1, :cutoff], ++ self.kv_cache[slot, cutoff:win], ++ self.kv_cache[slot, :cutoff], + ) = kv[ + :, -win: + ].split([win - cutoff, cutoff], dim=1) + if self.compress_ratio: +- if (kv_compress := self.compressor(x, start_pos)) is not None: ++ if (kv_compress := self.compressor(x, start_pos, cache_slot)) is not None: + kv = torch.cat([kv, kv_compress], dim=1) + o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale) + else: +- self.kv_cache[:1, start_pos % win] = kv.squeeze(1) ++ self.kv_cache[slot, start_pos % win] = kv.squeeze(1) + if self.compress_ratio: +- self.compressor(x, start_pos) ++ self.compressor(x, start_pos, cache_slot) + o = sparse_attn( + q, +- self.kv_cache[:1], ++ self.kv_cache[slot], + self.attn_sink, + topk_idxs, + self.softmax_scale, +@@ -1599,6 +1607,7 @@ class Block(nn.Module): + x: torch.Tensor, + start_pos: int, + input_ids: Optional[torch.Tensor], ++ cache_slot: int = 0, + ) -> torch.Tensor: + # ----- Attention sub-layer with mHC mixing ----- + residual = x +@@ -1606,7 +1615,7 @@ class Block(nn.Module): + x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + x = self.attn_norm(x) +- x = self.attn(x, start_pos) ++ x = self.attn(x, start_pos, cache_slot) + x = self.hc_post(x, residual, post, comb) + + # ----- FFN sub-layer with mHC mixing ----- +@@ -1821,11 +1830,30 @@ class DeepseekV4Model(nn.Module): + self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) + ++ def _forward_one( ++ self, ++ input_ids: torch.Tensor, ++ start_pos: int, ++ cache_slot: int, ++ ) -> torch.Tensor: ++ h = self.embed(input_ids) # [num_tokens, dim] ++ # Expand to hc_mult copies for Hyper-Connections: [num_tokens, hc, dim] ++ h = h.unsqueeze(-2).repeat(1, self.hc_mult, 1) ++ ++ for layer in self.layers: ++ h = layer(h, start_pos, input_ids, cache_slot) ++ ++ logits = self.head( ++ h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm ++ ) ++ return logits ++ + @torch.inference_mode() + def forward( + self, + input_ids: torch.Tensor, + start_pos: int = 0, ++ positions: Optional[torch.Tensor] = None, + **model_kwargs: dict, + ) -> torch.Tensor: + """Forward. 
+@@ -1844,17 +1872,51 @@ class DeepseekV4Model(nn.Module): + input_ids.size(0) == 1 + ), "B>1 batched input_ids needs attn_metadata; not supported yet" + input_ids = input_ids.flatten() +- h = self.embed(input_ids) # [num_tokens, dim] +- # Expand to hc_mult copies for Hyper-Connections: [num_tokens, hc, dim] +- h = h.unsqueeze(-2).repeat(1, self.hc_mult, 1) ++ if positions is None: ++ positions = torch.arange( ++ start_pos, ++ start_pos + input_ids.numel(), ++ device=input_ids.device, ++ dtype=torch.int64, ++ ) ++ else: ++ positions = positions.flatten() + +- for layer in self.layers: +- h = layer(h, start_pos, input_ids) ++ attn_metadata = None ++ context = None ++ try: ++ from atom.utils.forward_context import get_forward_context + +- logits = self.head( +- h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm +- ) +- return logits ++ forward_context = get_forward_context() ++ attn_metadata = forward_context.attn_metadata ++ context = forward_context.context ++ except Exception: ++ pass ++ ++ cu_seqlens_q = getattr(attn_metadata, "cu_seqlens_q", None) ++ if cu_seqlens_q is None or context is None or context.batch_size <= 1: ++ seq_start = int(positions[0].item()) if positions.numel() else int(start_pos) ++ cache_slots = getattr(attn_metadata, "dsv4_cache_slots", None) ++ cache_slot = int(cache_slots[0].item()) if cache_slots is not None else 0 ++ return self._forward_one(input_ids, seq_start, cache_slot) ++ ++ num_seqs = int(context.batch_size) ++ cache_slots = getattr(attn_metadata, "dsv4_cache_slots", None) ++ if cache_slots is None or cache_slots.numel() < num_seqs: ++ cache_slots = torch.arange(num_seqs, device=input_ids.device, dtype=torch.int64) ++ ++ logits = [] ++ for seq_idx in range(num_seqs): ++ start = int(cu_seqlens_q[seq_idx].item()) ++ end = int(cu_seqlens_q[seq_idx + 1].item()) ++ if end <= start: ++ continue ++ seq_start = int(positions[start].item()) ++ cache_slot = int(cache_slots[seq_idx].item()) ++ logits.append(self._forward_one(input_ids[start:end], seq_start, cache_slot)) ++ if not logits: ++ return self._forward_one(input_ids[:1], int(start_pos), 0) ++ return torch.cat(logits, dim=0) + + + class DeepseekV4ForCausalLM(nn.Module): +@@ -1918,6 +1980,9 @@ class DeepseekV4ForCausalLM(nn.Module): + # config lacks `quantization_config` (e.g. dummy / toy validation), + # this still works — base spec is QuantType.No. + self.args.quant_config = make_v4_quant_config(self.hf_config) ++ self.args.max_batch_size = max( ++ self.args.max_batch_size, int(getattr(config, "max_num_seqs", 1)) ++ ) + self.model = DeepseekV4Model(args=self.args) + + def forward( +@@ -1929,7 +1994,12 @@ class DeepseekV4ForCausalLM(nn.Module): + **model_kwargs: dict, + ) -> torch.Tensor: + start_pos = int(positions[0].item()) if positions is not None else 0 +- return self.model(input_ids=input_ids, start_pos=start_pos, **model_kwargs) ++ return self.model( ++ input_ids=input_ids, ++ start_pos=start_pos, ++ positions=positions, ++ **model_kwargs, ++ ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + # In V4, the LM head is fused into DeepseekV4Model.forward (it consumes +PATCH + # --no-deps: don't churn the image's pinned ROCm/torch/triton/aiter. # --force-reinstall: replace the wheel-installed atom with the editable copy. pip install --no-deps --force-reinstall -e . 
@@ -421,9 +864,9 @@ export ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS=${ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS:- # KV/GDN-mamba allocator overshoot the GPU budget ("GDN mamba tensor # exceeds available KV budget"), and using 1 hangs warmup at 0% GPU. 4 # is the minimum we've seen complete warmup successfully (also the PR's -# offline repro value). The PR1 kv_cache[:1,...] hardcode in -# deepseek_v4.py means any forward with batch>1 silently corrupts -# non-slot-0 lanes; eval (gsm8k) at conc>1 is the canary. +# offline repro value). The local PR650 overlay above maps each request to a +# persistent DSv4 cache slot; without it, deepseek_v4.py's `kv_cache[:1]` +# writes corrupt non-slot-0 lanes at CONC>1. MAX_NUM_SEQS=$(( CONC < 4 ? 4 : CONC )) MAX_NUM_BATCHED_TOKENS=${MAX_NUM_BATCHED_TOKENS:-$MAX_MODEL_LEN_VALUE} python3 -m atom.entrypoints.openai_server \ diff --git a/perf-changelog.yaml b/perf-changelog.yaml index d6d144fa3..c63086425 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2013,4 +2013,12 @@ - "Advance the AITER DSv4 perf-stack base to ROCm/aiter main commit 8c27e66f, which includes merged ROCm/aiter#2916" - "Remove the local runtime aiter/ops/mhc.py patcher and AITER_MHC_FIX_SHA fallback now the mhc_pre device-allocation fix is upstream" - "Keep the open AITER performance PR cherry-picks (#2642, #2822, #2900); they are not merged upstream yet" - pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1229 + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Enable DSv4 ATOM multi-request smoke coverage by adding a local ROCm/ATOM#650 overlay that assigns persistent per-request cache slots for DSv4 KV/compressor/indexer state" + - "Expand MI355X ATOM DSv4 sweep from CONC=1 only to CONC=1/2/4 for 1k1k and 8k1k" + - "Bump the overlaid ROCm/ATOM#650 SHA to af17eb8; the patch unblocks correctness but still runs DSv4 requests sequentially until upstream vectorized sparse-attention/cache support lands" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX From 21c7d873d05373e6573df4779db75d39fed02f0c Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Wed, 29 Apr 2026 21:09:45 -0700 Subject: [PATCH 04/14] higher conc optims --- .github/configs/amd-master.yaml | 5 +- .../single_node/dsv4_fp4_mi355x_atom.sh | 209 +++++++++++++++--- 2 files changed, 184 insertions(+), 30 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 2045a110d..eaf7d36fa 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1541,9 +1541,8 @@ dsv4-fp8-mi355x-vllm: # Day-0 DeepSeek-V4 on ATOM (ROCm/ATOM#650) with local runtime overlays. # dsv4_fp4_mi355x_atom.sh patches PR650 to give each request persistent DSv4 # KV/compressor/indexer cache slots, unblocking CONC>1 smoke coverage. The path -# still uses eager execution and request-looped DSv4 attention until upstream -# lands native multi-request sparse-attention/cache support, so keep this sweep -# conservative. +# still uses eager execution and per-sequence sparse attention, but batches +# attention projections, mHC, and MoE/FFN layer-by-layer across active requests. 
dsv4-fp4-mi355x-atom: image: rocm/atom:rocm7.2.2_ubuntu24.04_py3.12_pytorch_release_2.10.0_atom0.1.2.post model: deepseek-ai/DeepSeek-V4-Pro diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index ca5afbec1..bf74fe40b 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -21,9 +21,8 @@ echo "TP: $TP, CONC: $CONC, ISL: $ISL, OSL: $OSL, EP_SIZE: $EP_SIZE" # ROCm/ATOM#650 is still a PR1 DSv4 skeleton. The local overlay below gives # DSv4 persistent per-request cache slots so CONC>1 no longer corrupts the -# recurrent KV/compressor/indexer state. It is still eager and not vectorized -# across requests, so keep sweep points modest until upstream lands the native -# multi-request sparse-attention/cache path. +# recurrent KV/compressor/indexer state. It keeps sparse attention per sequence, +# but batches attention projections, mHC, and MoE/FFN work layer-by-layer. if [ "$EP_SIZE" -ne 1 ]; then echo "FATAL: ROCm/ATOM#650 PR1 has not validated expert parallel serving; EP_SIZE must be 1, got $EP_SIZE" >&2 exit 1 @@ -283,11 +282,12 @@ PYEOF # Local multi-request overlay for ROCm/ATOM#650. ATOM's scheduler passes # DSv4 a token-flat batch, but PR650 treats every request as cache slot 0 # (`kv_cache[:1]` and matching compressor/indexer state). Reuse ATOM's - # mamba-state slot allocator for DSv4, then run each sequence against its - # persistent slot. This fixes correctness for CONC>1; it is intentionally - # conservative and still loops requests until upstream vectorizes the DSv4 + # mamba-state slot allocator for DSv4, then split only sparse attention and + # cache mutation per sequence while batching attention projections, mHC, and + # MoE/FFN layer-by-layer. This fixes correctness for CONC>1 and avoids the + # worst all-layers-per-request loop until upstream vectorizes the DSv4 # sparse-attention/cache path. - sed 's/^$/ /' <<'PATCH' | git apply + sed 's/^$/ /' <<'PATCH' | git apply --recount diff --git a/atom/model_engine/llm_engine.py b/atom/model_engine/llm_engine.py index 8de9532..ddde446 100644 --- a/atom/model_engine/llm_engine.py @@ -560,7 +560,7 @@ index 46cf1b0..0d84c78 100644 else: compress_topk_idxs = _get_compress_topk_idxs( ratio, 1, seqlen, start_pos, offset, device=x.device -@@ -1037,26 +1045,26 @@ class DeepseekV4Attention(nn.Module): +@@ -1037,42 +1045,136 @@ class DeepseekV4Attention(nn.Module): # implicit B=1.) ----- if start_pos == 0: if seqlen <= win: @@ -594,7 +594,116 @@ index 46cf1b0..0d84c78 100644 self.attn_sink, topk_idxs, self.softmax_scale, -@@ -1599,6 +1607,7 @@ class Block(nn.Module): + ) + + # Inverse RoPE on output's rope dims to remove absolute-position contribution + # carried in by the value-side RoPE of the KV entries. + _apply_rotary_emb(o[..., -rd:], freqs_cis, inverse=True) + + # ----- Grouped output LoRA ----- + # o: [1, S, H, D] → drop B; reshape into groups for the einsum. 
+ o = o.squeeze(0).view(seqlen, self.n_local_groups, -1) # [S, g, H/g * D] + wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) + o = torch.einsum("sgd,grd->sgr", o, wo_a) # [S, g, o_lora_rank] + x = self.wo_b(o.flatten(1)) # 2D [S, dim] + return x + ++ def forward_batched( ++ self, x: torch.Tensor, seq_meta: list[tuple[int, int, int, int]] ++ ) -> torch.Tensor: ++ assert ( ++ x.dim() == 2 ++ ), f"DeepseekV4Attention expects 2D [num_tokens, dim], got {x.shape}" ++ total_tokens = x.size(0) ++ win = self.window_size ++ ratio = self.compress_ratio ++ rd = self.rope_head_dim ++ ++ if self.compress_ratio and self.compressor.kv_cache is None: ++ self.compressor.kv_cache = self.kv_cache[:, win:] ++ self.compressor.freqs_cis = self.freqs_cis ++ if self.indexer is not None: ++ self.indexer.freqs_cis = self.freqs_cis ++ ++ qr_all = self.q_norm(self.wq_a(x)) ++ q_all = self.wq_b(qr_all).view(total_tokens, self.n_local_heads, self.head_dim) ++ q_all = q_all * torch.rsqrt(q_all.square().mean(-1, keepdim=True) + self.eps) ++ kv_all = self.kv_norm(self.wkv(x)).view(total_tokens, self.head_dim) ++ ++ outputs = [] ++ for start, end, start_pos, cache_slot in seq_meta: ++ seqlen = end - start ++ freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] ++ slot = slice(cache_slot, cache_slot + 1) ++ ++ if start_pos == 0: ++ self.kv_cache[slot].zero_() ++ if self.compress_ratio: ++ self.compressor.kv_state[slot].zero_() ++ self.compressor.score_state[slot].fill_(float("-inf")) ++ if self.indexer is not None: ++ self.indexer.kv_cache[slot].zero_() ++ self.indexer.compressor.kv_state[slot].zero_() ++ self.indexer.compressor.score_state[slot].fill_(float("-inf")) ++ ++ q = q_all[start:end].unsqueeze(0) ++ _apply_rotary_emb(q[..., -rd:], freqs_cis) ++ kv = kv_all[start:end].unsqueeze(0) ++ _apply_rotary_emb(kv[..., -rd:], freqs_cis) ++ act_quant_inplace(kv[..., :-rd], 64, self.scale_fmt) ++ ++ topk_idxs = _get_window_topk_idxs( ++ win, 1, seqlen, start_pos, device=x.device ++ ) ++ if self.compress_ratio: ++ offset = kv.size(1) if start_pos == 0 else win ++ if self.indexer is not None: ++ compress_topk_idxs = self.indexer( ++ x[start:end], qr_all[start:end], start_pos, offset, cache_slot ++ ) ++ else: ++ compress_topk_idxs = _get_compress_topk_idxs( ++ ratio, 1, seqlen, start_pos, offset, device=x.device ++ ) ++ topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1) ++ topk_idxs = topk_idxs.int() ++ ++ if start_pos == 0: ++ if seqlen <= win: ++ self.kv_cache[slot, :seqlen] = kv ++ else: ++ cutoff = seqlen % win ++ ( ++ self.kv_cache[slot, cutoff:win], ++ self.kv_cache[slot, :cutoff], ++ ) = kv[:, -win:].split([win - cutoff, cutoff], dim=1) ++ if self.compress_ratio: ++ kv_compress = self.compressor(x[start:end], start_pos, cache_slot) ++ if kv_compress is not None: ++ kv = torch.cat([kv, kv_compress], dim=1) ++ o = sparse_attn(q, kv, self.attn_sink, topk_idxs, self.softmax_scale) ++ else: ++ self.kv_cache[slot, start_pos % win] = kv.squeeze(1) ++ if self.compress_ratio: ++ self.compressor(x[start:end], start_pos, cache_slot) ++ o = sparse_attn( ++ q, ++ self.kv_cache[slot], ++ self.attn_sink, ++ topk_idxs, ++ self.softmax_scale, ++ ) ++ ++ _apply_rotary_emb(o[..., -rd:], freqs_cis, inverse=True) ++ outputs.append(o.squeeze(0)) ++ ++ o = torch.cat(outputs, dim=0).view(total_tokens, self.n_local_groups, -1) ++ wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) ++ o = torch.einsum("sgd,grd->sgr", o, wo_a) ++ return self.wo_b(o.flatten(1)) ++ + +@@ -1599,6 +1701,7 @@ 
class Block(nn.Module): x: torch.Tensor, start_pos: int, input_ids: Optional[torch.Tensor], @@ -602,7 +711,7 @@ index 46cf1b0..0d84c78 100644 ) -> torch.Tensor: # ----- Attention sub-layer with mHC mixing ----- residual = x -@@ -1606,7 +1615,7 @@ class Block(nn.Module): +@@ -1606,7 +1709,7 @@ class Block(nn.Module): x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base ) x = self.attn_norm(x) @@ -611,7 +720,7 @@ index 46cf1b0..0d84c78 100644 x = self.hc_post(x, residual, post, comb) # ----- FFN sub-layer with mHC mixing ----- -@@ -1821,11 +1830,30 @@ class DeepseekV4Model(nn.Module): +@@ -1821,11 +1924,81 @@ class DeepseekV4Model(nn.Module): self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) @@ -632,6 +741,57 @@ index 46cf1b0..0d84c78 100644 + h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.norm + ) + return logits ++ ++ def _head_tokens(self, h: torch.Tensor) -> torch.Tensor: ++ x = self.head.hc_head( ++ h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base ++ ) ++ return F.linear(self.norm(x).float(), self.head.weight) ++ ++ def _forward_layerwise_batched( ++ self, ++ input_ids: torch.Tensor, ++ positions: torch.Tensor, ++ cu_seqlens_q: torch.Tensor, ++ cache_slots: torch.Tensor, ++ num_seqs: int, ++ ) -> torch.Tensor: ++ seq_meta: list[tuple[int, int, int, int]] = [] ++ last_indices: list[int] = [] ++ for seq_idx in range(num_seqs): ++ start = int(cu_seqlens_q[seq_idx].item()) ++ end = int(cu_seqlens_q[seq_idx + 1].item()) ++ if end <= start: ++ continue ++ seq_start = int(positions[start].item()) ++ cache_slot = int(cache_slots[seq_idx].item()) ++ seq_meta.append((start, end, seq_start, cache_slot)) ++ last_indices.append(end - 1) ++ if not seq_meta: ++ return self._forward_one(input_ids[:1], 0, 0) ++ ++ h = self.embed(input_ids) ++ h = h.unsqueeze(-2).repeat(1, self.hc_mult, 1) ++ ++ for layer in self.layers: ++ residual = h ++ x, post, comb = layer.hc_pre( ++ h, layer.hc_attn_fn, layer.hc_attn_scale, layer.hc_attn_base ++ ) ++ x = layer.attn_norm(x) ++ x = layer.attn.forward_batched(x, seq_meta) ++ h = layer.hc_post(x, residual, post, comb) ++ ++ residual = h ++ x, post, comb = layer.hc_pre( ++ h, layer.hc_ffn_fn, layer.hc_ffn_scale, layer.hc_ffn_base ++ ) ++ x = layer.ffn_norm(x) ++ x = layer.ffn(x, input_ids) ++ h = layer.hc_post(x, residual, post, comb) ++ ++ last_indices_t = torch.tensor(last_indices, device=h.device, dtype=torch.long) ++ return self._head_tokens(h.index_select(0, last_indices_t)) + @torch.inference_mode() def forward( @@ -642,7 +802,7 @@ index 46cf1b0..0d84c78 100644 **model_kwargs: dict, ) -> torch.Tensor: """Forward. 
-@@ -1844,17 +1872,51 @@ class DeepseekV4Model(nn.Module): +@@ -1844,17 +2017,42 @@ class DeepseekV4Model(nn.Module): input_ids.size(0) == 1 ), "B>1 batched input_ids needs attn_metadata; not supported yet" input_ids = input_ids.flatten() @@ -688,22 +848,13 @@ index 46cf1b0..0d84c78 100644 + if cache_slots is None or cache_slots.numel() < num_seqs: + cache_slots = torch.arange(num_seqs, device=input_ids.device, dtype=torch.int64) + -+ logits = [] -+ for seq_idx in range(num_seqs): -+ start = int(cu_seqlens_q[seq_idx].item()) -+ end = int(cu_seqlens_q[seq_idx + 1].item()) -+ if end <= start: -+ continue -+ seq_start = int(positions[start].item()) -+ cache_slot = int(cache_slots[seq_idx].item()) -+ logits.append(self._forward_one(input_ids[start:end], seq_start, cache_slot)) -+ if not logits: -+ return self._forward_one(input_ids[:1], int(start_pos), 0) -+ return torch.cat(logits, dim=0) ++ return self._forward_layerwise_batched( ++ input_ids, positions, cu_seqlens_q, cache_slots, num_seqs ++ ) class DeepseekV4ForCausalLM(nn.Module): -@@ -1918,6 +1980,9 @@ class DeepseekV4ForCausalLM(nn.Module): +@@ -1918,6 +2116,9 @@ class DeepseekV4ForCausalLM(nn.Module): # config lacks `quantization_config` (e.g. dummy / toy validation), # this still works — base spec is QuantType.No. self.args.quant_config = make_v4_quant_config(self.hf_config) @@ -713,7 +864,7 @@ index 46cf1b0..0d84c78 100644 self.model = DeepseekV4Model(args=self.args) def forward( -@@ -1929,7 +1994,12 @@ class DeepseekV4ForCausalLM(nn.Module): +@@ -1929,7 +2130,12 @@ class DeepseekV4ForCausalLM(nn.Module): **model_kwargs: dict, ) -> torch.Tensor: start_pos = int(positions[0].item()) if positions is not None else 0 @@ -868,7 +1019,11 @@ export ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS=${ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS:- # persistent DSv4 cache slot; without it, deepseek_v4.py's `kv_cache[:1]` # writes corrupt non-slot-0 lanes at CONC>1. MAX_NUM_SEQS=$(( CONC < 4 ? 4 : CONC )) -MAX_NUM_BATCHED_TOKENS=${MAX_NUM_BATCHED_TOKENS:-$MAX_MODEL_LEN_VALUE} +# Allow prefill batching again. The layer-wise DSv4 overlay splits attention by +# sequence before sparse_attn, so two 8k-ish prompts no longer become one giant +# sparse-attention problem, while MoE/FFN still sees the larger token batch. +DEFAULT_MAX_NUM_BATCHED_TOKENS=$(( MAX_MODEL_LEN_VALUE > 16384 ? 
MAX_MODEL_LEN_VALUE : 16384 )) +MAX_NUM_BATCHED_TOKENS=${MAX_NUM_BATCHED_TOKENS:-$DEFAULT_MAX_NUM_BATCHED_TOKENS} python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ From ce4cb44af16b8a53b5409886bb196912b4eeacdf Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 06:43:50 -0700 Subject: [PATCH 05/14] Optimize DSv4 ATOM profiling and decode batching --- .github/workflows/claude.yml | 22 +++-- .github/workflows/profile.yml | 33 +++++-- benchmarks/benchmark_lib.sh | 25 ++++-- .../single_node/dsr1_fp4_mi355x_atom.sh | 6 +- .../single_node/dsr1_fp4_mi355x_atom_mtp.sh | 2 + .../single_node/dsr1_fp8_mi355x_atom.sh | 6 +- .../single_node/dsr1_fp8_mi355x_atom_mtp.sh | 2 + .../single_node/dsv4_fp4_mi355x_atom.sh | 86 +++++++++++++++++++ .../single_node/glm5.1_fp4_mi355x_atom.sh | 2 + .../single_node/glm5_fp8_mi355x_atom.sh | 2 + .../single_node/gptoss_fp4_mi355x_atom.sh | 6 +- .../single_node/kimik2.5_fp4_mi355x_atom.sh | 2 + .../minimaxm2.5_fp4_mi355x_atom.sh | 2 + .../minimaxm2.5_fp8_mi355x_atom.sh | 2 + .../single_node/qwen3.5_fp8_mi355x_atom.sh | 2 + .../qwen3.5_fp8_mi355x_atom_mtp.sh | 2 + perf-changelog.yaml | 7 ++ utils/bench_serving/backend_request_func.py | 6 ++ 18 files changed, 187 insertions(+), 28 deletions(-) diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index b5b474471..182aefef7 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -161,8 +161,8 @@ jobs: - If jobs cannot be run, say exactly what you could not run and why - **Important** Modify perf-changelog.yaml for any config changes affecting performance - ## Profiling (SGLang only) - When asked to profile a config, dispatch the `profile.yml` workflow. **Only SGLang configs can be profiled** — the profiler uses SGLang's `/start_profile` and `/stop_profile` HTTP endpoints. Reject profiling requests for vLLM, TRT, or other frameworks. + ## Profiling + When asked to profile a config, dispatch the `profile.yml` workflow. SGLang, vLLM, and ATOM single-node configs can be profiled through their `/start_profile` and `/stop_profile` HTTP endpoints when the server is launched with the corresponding torch profiler directory. Reject profiling requests for TRT, disaggregated/multi-node configs, or other frameworks. **Syntax:** ``` @@ -172,9 +172,10 @@ jobs: workflow_id="profile.yml", ref="main", inputs={ - "config-key": "", + "config-key": "", "config-file": "<.github/configs/nvidia-master.yaml or amd-master.yaml>", - "conc": "" + "conc": "", + "seq-len": "<1k1k or 8k1k>" } ) ``` @@ -184,19 +185,16 @@ jobs: - Model: "deepseek" / "dsr1" → model-prefix `dsr1`; "gptoss" → `gptoss`; "qwen" → `qwen3.5` - Precision: "fp4" / "fp8" / "bf16" - Runner/hardware: "b200", "h200", "h100", "mi300x", "mi325x", "mi355x", etc. - - Framework: must be "sglang" (reject if not) + - Framework: must be "sglang", "vllm", or "atom" (reject TRT and disaggregated/multi-node) - Concurrency: "conc=N" → `"conc": "N"`. Default to `"64"` if not specified. + - Sequence length: default to `"1k1k"` unless the user asks for `"8k1k"`. 
- Construct the config-key as: `{model-prefix}-{precision}-{runner}-sglang` + Construct the config-key as: `{model-prefix}-{precision}-{runner}-{framework}` Choose config-file: NVIDIA runners (b200, h200, h100, gb200, gb300) → `nvidia-master.yaml`; AMD runners (mi300x, mi325x, mi355x) → `amd-master.yaml` - **Available SGLang config keys:** - NVIDIA: `dsr1-fp4-b200-sglang`, `dsr1-fp8-b200-sglang`, `dsr1-fp8-h200-sglang`, `qwen3.5-bf16-b200-sglang` - AMD: `dsr1-fp4-mi355x-sglang`, `dsr1-fp8-mi300x-sglang`, `dsr1-fp8-mi325x-sglang`, `dsr1-fp8-mi355x-sglang`, `qwen3.5-bf16-mi355x-sglang`, `qwen3.5-fp8-mi355x-sglang` - **Examples:** - - "profile sglang b200 deepseek fp4 conc=4" → `config-key: dsr1-fp4-b200-sglang`, `config-file: .github/configs/nvidia-master.yaml`, `conc: 4` - - "profile sglang mi355x dsr1 fp8" → `config-key: dsr1-fp8-mi355x-sglang`, `config-file: .github/configs/amd-master.yaml`, `conc: 64` + - "profile sglang b200 deepseek fp4 conc=4" → `config-key: dsr1-fp4-b200-sglang`, `config-file: .github/configs/nvidia-master.yaml`, `conc: 4`, `seq-len: 1k1k` + - "profile atom mi355x dsv4 fp4 conc=4 8k1k" → `config-key: dsv4-fp4-mi355x-atom`, `config-file: .github/configs/amd-master.yaml`, `conc: 4`, `seq-len: 8k1k` **After dispatch:** Monitor with `mcp__github__get_workflow_run`. The profile workflow takes ~15-30 minutes. When complete, the **Perfetto relay link** is in the workflow run's step summary. Retrieve it with: diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml index 8152d47a5..56894b84d 100644 --- a/.github/workflows/profile.yml +++ b/.github/workflows/profile.yml @@ -17,6 +17,14 @@ on: required: false type: string default: '64' + seq-len: + description: "Sequence length config to profile" + required: false + type: choice + options: + - 1k1k + - 8k1k + default: 1k1k moe-debug: description: "Enable MoE debug patch and log (MOE_DEBUG_LOG)" required: false @@ -54,7 +62,7 @@ jobs: name: Generate matrix via script run: | pip install pydantic - CLI_ARGS="test-config --config-files ${{ inputs.config-file }} --config-keys ${{ inputs.config-key }} --conc ${{ inputs.conc }}" + CLI_ARGS="test-config --config-files ${{ inputs.config-file }} --config-keys ${{ inputs.config-key }} --conc ${{ inputs.conc }} --seq-lens ${{ inputs.seq-len }}" CONFIG_JSON=$(python3 ${GITHUB_WORKSPACE}/utils/matrix_logic/generate_sweep_configs.py $CLI_ARGS) echo "raw=$CONFIG_JSON" >> $GITHUB_OUTPUT @@ -148,13 +156,14 @@ jobs: ref: ${{ inputs.ref || github.sha }} clean: false - - name: Launch + Profile (single-node sglang/vllm) + - name: Launch + Profile (single-node) id: run env: RUNNER_NAME: ${{ runner.name }} PROFILE: '1' SGLANG_TORCH_PROFILER_DIR: /workspace/ VLLM_TORCH_PROFILER_DIR: /workspace/ + ATOM_TORCH_PROFILER_DIR: /workspace/atom_profiles VLLM_RPC_TIMEOUT: '1800000' shell: bash run: | @@ -193,16 +202,30 @@ jobs: fi else echo "Profile trace not found: $trace_path" >&2 + exit 1 fi - name: Process result (json -> agg) + continue-on-error: true env: RUNNER_TYPE: ${{ matrix.config.runner }} run: | python3 utils/process_result.py + - name: Upload profile diagnostics + if: ${{ always() && env.RESULT_FILENAME != '' }} + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: profile_diagnostics_${{ env.RESULT_FILENAME }} + path: | + ${{ env.RESULT_FILENAME }}.json + agg_${{ env.RESULT_FILENAME }}.json + server.log + gpu_metrics.csv + if-no-files-found: ignore + - name: Upload profile as artifact - if: ${{ steps.run.outputs.trace != '' }} + if: 
${{ always() && steps.run.outputs.trace != '' }} uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: profile_${{ env.RESULT_FILENAME }} @@ -210,7 +233,7 @@ jobs: if-no-files-found: ignore - name: Upload TP-0-DECODE trace as artifact - if: ${{ steps.run.outputs.tp0_decode != '' }} + if: ${{ always() && steps.run.outputs.tp0_decode != '' }} uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: profile_${{ env.RESULT_FILENAME }}_TP0_DECODE @@ -218,7 +241,7 @@ jobs: if-no-files-found: ignore - name: Upload TP-0-EXTEND trace as artifact - if: ${{ steps.run.outputs.tp0_extend != '' }} + if: ${{ always() && steps.run.outputs.tp0_extend != '' }} uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: profile_${{ env.RESULT_FILENAME }}_TP0_EXTEND diff --git a/benchmarks/benchmark_lib.sh b/benchmarks/benchmark_lib.sh index 268745735..bad0e0c7b 100644 --- a/benchmarks/benchmark_lib.sh +++ b/benchmarks/benchmark_lib.sh @@ -327,10 +327,12 @@ run_benchmark_serving() { # and cap num_prompts to keep traces small. local profile_flag=() if [[ "${PROFILE:-}" == "1" ]]; then - local _prof_dir="${SGLANG_TORCH_PROFILER_DIR:-${VLLM_TORCH_PROFILER_DIR:-}}" - if [[ -n "$_prof_dir" ]]; then - mkdir -p "$_prof_dir" - fi + local _prof_dir="" + for _prof_dir in "${SGLANG_TORCH_PROFILER_DIR:-}" "${VLLM_TORCH_PROFILER_DIR:-}" "${ATOM_TORCH_PROFILER_DIR:-}"; do + if [[ -n "$_prof_dir" ]]; then + mkdir -p "$_prof_dir" + fi + done profile_flag+=(--profile) num_prompts="$max_concurrency" fi @@ -415,6 +417,15 @@ run_benchmark_serving() { # Profiling trace helpers # -------------------------------- +setup_atom_profile_args() { + ATOM_PROFILE_ARGS=() + if [[ "${PROFILE:-}" == "1" ]]; then + ATOM_TORCH_PROFILER_DIR=${ATOM_TORCH_PROFILER_DIR:-/workspace/atom_profiles} + mkdir -p "$ATOM_TORCH_PROFILER_DIR" + ATOM_PROFILE_ARGS+=(--torch-profiler-dir "$ATOM_TORCH_PROFILER_DIR") + fi +} + _find_latest_profile_trace() { local latest="" local dir="" candidate="" base="" @@ -424,6 +435,9 @@ _find_latest_profile_trace() { search_roots=() if [[ -d "$dir" ]]; then search_roots+=("$dir") + while IFS= read -r -d '' candidate; do + search_roots+=("$candidate") + done < <(find "$dir" -mindepth 1 -maxdepth 1 -type d -print0 2>/dev/null) fi if [[ -d "$dir/profiles" ]]; then search_roots+=("$dir/profiles") @@ -463,11 +477,12 @@ move_profile_trace_for_relay() { local sglang_dir="${SGLANG_TORCH_PROFILER_DIR:-/workspace}" local vllm_dir="${VLLM_TORCH_PROFILER_DIR:-/workspace}" + local atom_dir="${ATOM_TORCH_PROFILER_DIR:-/workspace}" local -a search_dirs=() local dir="" existing="" local seen=0 - for dir in "$sglang_dir" "$vllm_dir" "/workspace"; do + for dir in "$sglang_dir" "$vllm_dir" "$atom_dir" "/workspace"; do if [[ -z "$dir" ]]; then continue fi diff --git a/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh b/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh index 31554fc22..7adee9f21 100644 --- a/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsr1_fp4_mi355x_atom.sh @@ -48,12 +48,14 @@ start_gpu_monitor set -x BLOCK_SIZE=${BLOCK_SIZE:-16} +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ - --block-size $BLOCK_SIZE > $SERVER_LOG 2>&1 & + --block-size $BLOCK_SIZE \ + "${ATOM_PROFILE_ARGS[@]}" > $SERVER_LOG 2>&1 & SERVER_PID=$! 
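For the profiling plumbing in this patch: SGLang, vLLM, and (once launched with `--torch-profiler-dir`) ATOM are assumed to expose `/start_profile` and `/stop_profile`, and benchmark_lib's `_find_latest_profile_trace` then picks up the newest chrome trace under the profiler dir or, after this patch, its immediate subdirectories. A minimal sketch of the server-side contract this relies on; handler names, trace filename, and wiring are illustrative, not ATOM's actual implementation:

# Sketch of the /start_profile & /stop_profile contract the harness assumes
# when a server is launched with a torch profiler directory. Handler names,
# trace filename, and wiring are illustrative, not ATOM's implementation.
import os
import time

from torch.profiler import ProfilerActivity, profile

TRACE_DIR = os.environ.get("ATOM_TORCH_PROFILER_DIR", "/workspace/atom_profiles")
_prof = None

def start_profile() -> None:
    global _prof
    _prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])
    _prof.__enter__()

def stop_profile() -> str:
    global _prof
    _prof.__exit__(None, None, None)
    os.makedirs(TRACE_DIR, exist_ok=True)
    # _find_latest_profile_trace scans for the newest *.trace.json under the
    # profiler dir and its immediate subdirectories.
    path = os.path.join(TRACE_DIR, f"rank0-{int(time.time())}.trace.json")
    _prof.export_chrome_trace(path)
    _prof = None
    return path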
@@ -80,4 +82,4 @@ fi # Stop GPU monitoring stop_gpu_monitor -set +x \ No newline at end of file +set +x diff --git a/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh b/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh index 1d557684e..9da5c778d 100644 --- a/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh +++ b/benchmarks/single_node/dsr1_fp4_mi355x_atom_mtp.sh @@ -49,12 +49,14 @@ set -x export AMDGCN_USE_BUFFER_OPS=1 +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --method mtp \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh b/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh index 31554fc22..7adee9f21 100644 --- a/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/dsr1_fp8_mi355x_atom.sh @@ -48,12 +48,14 @@ start_gpu_monitor set -x BLOCK_SIZE=${BLOCK_SIZE:-16} +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ - --block-size $BLOCK_SIZE > $SERVER_LOG 2>&1 & + --block-size $BLOCK_SIZE \ + "${ATOM_PROFILE_ARGS[@]}" > $SERVER_LOG 2>&1 & SERVER_PID=$! @@ -80,4 +82,4 @@ fi # Stop GPU monitoring stop_gpu_monitor -set +x \ No newline at end of file +set +x diff --git a/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh b/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh index 69179cec0..ea5bbc5b1 100644 --- a/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh +++ b/benchmarks/single_node/dsr1_fp8_mi355x_atom_mtp.sh @@ -47,6 +47,7 @@ start_gpu_monitor set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -54,6 +55,7 @@ python3 -m atom.entrypoints.openai_server \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --method mtp \ --num-speculative-tokens 3 \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! 
diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index bf74fe40b..c27bd200c 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -268,6 +268,31 @@ if marker not in source: out = (exp_scores / denom.clamp(min=1e-30)).matmul(kv_f32) return out.view(1, 1, H, D).to(out_dtype) + if M == 1: + valid_2d = topk_idxs[:, 0] != -1 + safe_idxs_2d = topk_idxs[:, 0].clamp(min=0).long() + batch_idx = torch.arange(B, device=device).view(B, 1).expand_as(safe_idxs_2d) + kv_f32 = kv[batch_idx, safe_idxs_2d].float() + kv_f32 = torch.where( + valid_2d.unsqueeze(-1), + kv_f32, + torch.zeros((), dtype=kv_f32.dtype, device=device), + ) + q_f32 = q[:, 0].float() + scores = torch.einsum("bhd,bkd->bhk", q_f32, kv_f32) * float(softmax_scale) + scores = scores.masked_fill(~valid_2d.unsqueeze(1), float("-inf")) + sink = attn_sink.float().view(1, H, 1) + cmax = torch.maximum(scores.amax(dim=-1, keepdim=True), sink) + cmax = torch.where( + cmax == float("-inf"), + torch.zeros((), dtype=cmax.dtype, device=device), + cmax, + ) + exp_scores = (scores - cmax).exp() + denom = exp_scores.sum(dim=-1, keepdim=True) + (sink - cmax).exp() + out = torch.einsum("bhk,bkd->bhd", exp_scores / denom.clamp(min=1e-30), kv_f32) + return out.view(B, 1, H, D).to(out_dtype) + # ----- Gather KV per query position ----- """ if old not in source: @@ -630,6 +655,65 @@ index 46cf1b0..0d84c78 100644 + q_all = q_all * torch.rsqrt(q_all.square().mean(-1, keepdim=True) + self.eps) + kv_all = self.kv_norm(self.wkv(x)).view(total_tokens, self.head_dim) + ++ if seq_meta and all( ++ end - start == 1 and start_pos > 0 ++ for start, end, start_pos, _ in seq_meta ++ ): ++ q_chunks = [] ++ kv_cache_chunks = [] ++ topk_chunks = [] ++ freqs_chunks = [] ++ for start, end, start_pos, cache_slot in seq_meta: ++ freqs_cis = self.freqs_cis[start_pos : start_pos + 1] ++ slot = slice(cache_slot, cache_slot + 1) ++ ++ q = q_all[start:end].unsqueeze(0) ++ _apply_rotary_emb(q[..., -rd:], freqs_cis) ++ kv = kv_all[start:end].unsqueeze(0) ++ _apply_rotary_emb(kv[..., -rd:], freqs_cis) ++ act_quant_inplace(kv[..., :-rd], 64, self.scale_fmt) ++ ++ topk_idxs = _get_window_topk_idxs( ++ win, 1, 1, start_pos, device=x.device ++ ) ++ if self.compress_ratio: ++ offset = win ++ if self.indexer is not None: ++ compress_topk_idxs = self.indexer( ++ x[start:end], qr_all[start:end], start_pos, offset, cache_slot ++ ) ++ else: ++ compress_topk_idxs = _get_compress_topk_idxs( ++ ratio, 1, 1, start_pos, offset, device=x.device ++ ) ++ topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1) ++ ++ self.kv_cache[slot, start_pos % win] = kv.squeeze(1) ++ if self.compress_ratio: ++ self.compressor(x[start:end], start_pos, cache_slot) ++ ++ q_chunks.append(q) ++ kv_cache_chunks.append(self.kv_cache[slot]) ++ topk_chunks.append(topk_idxs.int()) ++ freqs_chunks.append(freqs_cis) ++ ++ o = sparse_attn( ++ torch.cat(q_chunks, dim=0), ++ torch.cat(kv_cache_chunks, dim=0), ++ self.attn_sink, ++ torch.cat(topk_chunks, dim=0), ++ self.softmax_scale, ++ ) ++ out_chunks = [] ++ for idx, freqs_cis in enumerate(freqs_chunks): ++ o_i = o[idx : idx + 1] ++ _apply_rotary_emb(o_i[..., -rd:], freqs_cis, inverse=True) ++ out_chunks.append(o_i.squeeze(0)) ++ o = torch.cat(out_chunks, dim=0).view(total_tokens, self.n_local_groups, -1) ++ wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) ++ o = torch.einsum("sgd,grd->sgr", o, wo_a) ++ return self.wo_b(o.flatten(1)) 
++ + outputs = [] + for start, end, start_pos, cache_slot in seq_meta: + seqlen = end - start @@ -1024,6 +1108,7 @@ MAX_NUM_SEQS=$(( CONC < 4 ? 4 : CONC )) # sparse-attention problem, while MoE/FFN still sees the larger token batch. DEFAULT_MAX_NUM_BATCHED_TOKENS=$(( MAX_MODEL_LEN_VALUE > 16384 ? MAX_MODEL_LEN_VALUE : 16384 )) MAX_NUM_BATCHED_TOKENS=${MAX_NUM_BATCHED_TOKENS:-$DEFAULT_MAX_NUM_BATCHED_TOKENS} +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -1033,6 +1118,7 @@ python3 -m atom.entrypoints.openai_server \ --enforce-eager \ --max-num-seqs $MAX_NUM_SEQS \ --max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \ + "${ATOM_PROFILE_ARGS[@]}" \ --trust-remote-code > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh b/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh index 036346af3..410743d1b 100644 --- a/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/glm5.1_fp4_mi355x_atom.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x pip install -U transformers +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -51,6 +52,7 @@ python3 -m atom.entrypoints.openai_server \ --gpu-memory-utilization $MEM_FRAC_STATIC \ --default-chat-template-kwargs '{"enable_thinking": false}' \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/glm5_fp8_mi355x_atom.sh b/benchmarks/single_node/glm5_fp8_mi355x_atom.sh index 31bc8b25f..b398a4ea4 100644 --- a/benchmarks/single_node/glm5_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/glm5_fp8_mi355x_atom.sh @@ -42,6 +42,7 @@ start_gpu_monitor set -x pip install -U transformers +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -49,6 +50,7 @@ python3 -m atom.entrypoints.openai_server \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --default-chat-template-kwargs '{"enable_thinking": false}' \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh b/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh index 76bc87c0c..455444370 100644 --- a/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/gptoss_fp4_mi355x_atom.sh @@ -49,12 +49,14 @@ set -x BLOCK_SIZE=${BLOCK_SIZE:-16} export ATOM_GPT_OSS_MODEL=1 #TODO remove this +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ - --block-size $BLOCK_SIZE > $SERVER_LOG 2>&1 & + --block-size $BLOCK_SIZE \ + "${ATOM_PROFILE_ARGS[@]}" > $SERVER_LOG 2>&1 & SERVER_PID=$! @@ -81,4 +83,4 @@ fi # Stop GPU monitoring stop_gpu_monitor -set +x \ No newline at end of file +set +x diff --git a/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh b/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh index ca84f8228..0eeb3e6aa 100755 --- a/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/kimik2.5_fp4_mi355x_atom.sh @@ -42,12 +42,14 @@ start_gpu_monitor set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! 
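The B>1,M=1 fast path patched into sparse_attn_v4 earlier in this patch collapses steady-state decode into one batched masked softmax with an attention sink. Its numerics can be exercised standalone; the shapes below are illustrative, not the model's:

    import torch

    B, H, K, D = 2, 4, 8, 16            # illustrative sizes
    q = torch.randn(B, H, D)            # one query token per request
    kv = torch.randn(B, K, D)           # gathered top-k KV rows
    valid = torch.rand(B, K) > 0.25     # False marks padded (-1) top-k slots
    sink = torch.randn(H)               # per-head attention-sink logit
    scale = D ** -0.5

    scores = torch.einsum("bhd,bkd->bhk", q, kv) * scale
    scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf"))
    # The sink joins the running max and the denominator but has no value row.
    cmax = torch.maximum(scores.amax(dim=-1, keepdim=True), sink.view(1, H, 1))
    exp_scores = (scores - cmax).exp()
    denom = exp_scores.sum(dim=-1, keepdim=True) + (sink.view(1, H, 1) - cmax).exp()
    out = torch.einsum("bhk,bkd->bhd", exp_scores / denom.clamp(min=1e-30), kv)
    assert out.shape == (B, H, D)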
diff --git a/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh b/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh index ca84f8228..0eeb3e6aa 100644 --- a/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/minimaxm2.5_fp4_mi355x_atom.sh @@ -42,12 +42,14 @@ start_gpu_monitor set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh b/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh index ca84f8228..0eeb3e6aa 100755 --- a/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/minimaxm2.5_fp8_mi355x_atom.sh @@ -42,12 +42,14 @@ start_gpu_monitor set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ -tp $TP \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh index 2a8c67da0..f9bb0d5cd 100644 --- a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh +++ b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -50,6 +51,7 @@ python3 -m atom.entrypoints.openai_server \ --kv_cache_dtype fp8 $CALCULATED_MAX_MODEL_LEN $EP \ --gpu-memory-utilization $MEM_FRAC_STATIC \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! diff --git a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh index 9399fe792..8110ac124 100644 --- a/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh +++ b/benchmarks/single_node/qwen3.5_fp8_mi355x_atom_mtp.sh @@ -43,6 +43,7 @@ MEM_FRAC_STATIC=0.9 set -x +setup_atom_profile_args python3 -m atom.entrypoints.openai_server \ --model $MODEL \ --server-port $PORT \ @@ -52,6 +53,7 @@ python3 -m atom.entrypoints.openai_server \ --method mtp \ --num-speculative-tokens 3 \ --trust-remote-code \ + "${ATOM_PROFILE_ARGS[@]}" \ > $SERVER_LOG 2>&1 & SERVER_PID=$! 
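The deepseek_v4 overlay earlier in this patch applies the matching idea one level up: per-request one-token attention work is stacked into a single launch per layer and split afterwards. Stripped of RoPE, quantization, and cache details, the control flow reduces to the sketch below; the names are illustrative, not the ATOM API:

    import torch

    def batched_one_token_step(requests, attn_fn):
        # Each request contributes one query token, a KV-cache view, and its
        # top-k indices; stacking them lets attention launch once per layer
        # instead of once per request.
        q = torch.stack([r["q"] for r in requests])             # [B, 1, H, D]
        kv = torch.stack([r["kv_cache"] for r in requests])     # [B, N, D]
        topk = torch.stack([r["topk_idxs"] for r in requests])  # [B, 1, K]
        out = attn_fn(q, kv, topk)                              # [B, 1, H, D]
        return list(out.unbind(0))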
diff --git a/perf-changelog.yaml b/perf-changelog.yaml index c63086425..4178197e2 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2022,3 +2022,10 @@ - "Expand MI355X ATOM DSv4 sweep from CONC=1 only to CONC=1/2/4 for 1k1k and 8k1k" - "Bump the overlaid ROCm/ATOM#650 SHA to af17eb8; the patch unblocks correctness but still runs DSv4 requests sequentially until upstream vectorized sparse-attention/cache support lands" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Optimize DSv4 ATOM CONC=4 decode path by batching one-token sparse-attention calls across active requests inside the local ROCm/ATOM#650 overlay" + - "Add a vectorized B>1,M=1 sparse_attn_v4 torch fallback so steady-state decode avoids one sparse-attention launch sequence per request per layer" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX diff --git a/utils/bench_serving/backend_request_func.py b/utils/bench_serving/backend_request_func.py index af030720e..0900a8e6b 100644 --- a/utils/bench_serving/backend_request_func.py +++ b/utils/bench_serving/backend_request_func.py @@ -273,6 +273,12 @@ async def async_request_openai_completions( async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: + if api_url.endswith("profile"): + output.latency = time.perf_counter() - st + output.generated_text = await response.text() + output.success = True + return output + first_chunk_received = False async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() From 6f4600a924dc6ea713ccead9d499273454bbb341 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 11:26:41 -0700 Subject: [PATCH 06/14] Constrain DSv4 ATOM profile window --- .github/workflows/profile.yml | 2 + benchmarks/benchmark_lib.sh | 5 +- .../single_node/dsv4_fp4_mi355x_atom.sh | 84 ------------------- perf-changelog.yaml | 7 ++ utils/bench_serving/benchmark_serving.py | 3 +- 5 files changed, 14 insertions(+), 87 deletions(-) diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml index 56894b84d..b9d226b2c 100644 --- a/.github/workflows/profile.yml +++ b/.github/workflows/profile.yml @@ -164,6 +164,8 @@ jobs: SGLANG_TORCH_PROFILER_DIR: /workspace/ VLLM_TORCH_PROFILER_DIR: /workspace/ ATOM_TORCH_PROFILER_DIR: /workspace/atom_profiles + PROFILE_NUM_STEPS: '1' + PROFILE_OUTPUT_LEN: '1' VLLM_RPC_TIMEOUT: '1800000' shell: bash run: | diff --git a/benchmarks/benchmark_lib.sh b/benchmarks/benchmark_lib.sh index bad0e0c7b..41fc891b3 100644 --- a/benchmarks/benchmark_lib.sh +++ b/benchmarks/benchmark_lib.sh @@ -324,7 +324,7 @@ run_benchmark_serving() { fi # Profiling support: when PROFILE=1, ensure profiler dir exists, add --profile flag, - # and cap num_prompts to keep traces small. + # and cap the run to a tiny one-step window by default. 
local profile_flag=() if [[ "${PROFILE:-}" == "1" ]]; then local _prof_dir="" @@ -334,7 +334,8 @@ run_benchmark_serving() { fi done profile_flag+=(--profile) - num_prompts="$max_concurrency" + num_prompts="${PROFILE_NUM_PROMPTS:-$max_concurrency}" + output_len="${PROFILE_OUTPUT_LEN:-${PROFILE_NUM_STEPS:-1}}" fi # Build benchmark command diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index c27bd200c..2d420853d 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -268,31 +268,6 @@ if marker not in source: out = (exp_scores / denom.clamp(min=1e-30)).matmul(kv_f32) return out.view(1, 1, H, D).to(out_dtype) - if M == 1: - valid_2d = topk_idxs[:, 0] != -1 - safe_idxs_2d = topk_idxs[:, 0].clamp(min=0).long() - batch_idx = torch.arange(B, device=device).view(B, 1).expand_as(safe_idxs_2d) - kv_f32 = kv[batch_idx, safe_idxs_2d].float() - kv_f32 = torch.where( - valid_2d.unsqueeze(-1), - kv_f32, - torch.zeros((), dtype=kv_f32.dtype, device=device), - ) - q_f32 = q[:, 0].float() - scores = torch.einsum("bhd,bkd->bhk", q_f32, kv_f32) * float(softmax_scale) - scores = scores.masked_fill(~valid_2d.unsqueeze(1), float("-inf")) - sink = attn_sink.float().view(1, H, 1) - cmax = torch.maximum(scores.amax(dim=-1, keepdim=True), sink) - cmax = torch.where( - cmax == float("-inf"), - torch.zeros((), dtype=cmax.dtype, device=device), - cmax, - ) - exp_scores = (scores - cmax).exp() - denom = exp_scores.sum(dim=-1, keepdim=True) + (sink - cmax).exp() - out = torch.einsum("bhk,bkd->bhd", exp_scores / denom.clamp(min=1e-30), kv_f32) - return out.view(B, 1, H, D).to(out_dtype) - # ----- Gather KV per query position ----- """ if old not in source: @@ -655,65 +630,6 @@ index 46cf1b0..0d84c78 100644 + q_all = q_all * torch.rsqrt(q_all.square().mean(-1, keepdim=True) + self.eps) + kv_all = self.kv_norm(self.wkv(x)).view(total_tokens, self.head_dim) + -+ if seq_meta and all( -+ end - start == 1 and start_pos > 0 -+ for start, end, start_pos, _ in seq_meta -+ ): -+ q_chunks = [] -+ kv_cache_chunks = [] -+ topk_chunks = [] -+ freqs_chunks = [] -+ for start, end, start_pos, cache_slot in seq_meta: -+ freqs_cis = self.freqs_cis[start_pos : start_pos + 1] -+ slot = slice(cache_slot, cache_slot + 1) -+ -+ q = q_all[start:end].unsqueeze(0) -+ _apply_rotary_emb(q[..., -rd:], freqs_cis) -+ kv = kv_all[start:end].unsqueeze(0) -+ _apply_rotary_emb(kv[..., -rd:], freqs_cis) -+ act_quant_inplace(kv[..., :-rd], 64, self.scale_fmt) -+ -+ topk_idxs = _get_window_topk_idxs( -+ win, 1, 1, start_pos, device=x.device -+ ) -+ if self.compress_ratio: -+ offset = win -+ if self.indexer is not None: -+ compress_topk_idxs = self.indexer( -+ x[start:end], qr_all[start:end], start_pos, offset, cache_slot -+ ) -+ else: -+ compress_topk_idxs = _get_compress_topk_idxs( -+ ratio, 1, 1, start_pos, offset, device=x.device -+ ) -+ topk_idxs = torch.cat([topk_idxs, compress_topk_idxs], dim=-1) -+ -+ self.kv_cache[slot, start_pos % win] = kv.squeeze(1) -+ if self.compress_ratio: -+ self.compressor(x[start:end], start_pos, cache_slot) -+ -+ q_chunks.append(q) -+ kv_cache_chunks.append(self.kv_cache[slot]) -+ topk_chunks.append(topk_idxs.int()) -+ freqs_chunks.append(freqs_cis) -+ -+ o = sparse_attn( -+ torch.cat(q_chunks, dim=0), -+ torch.cat(kv_cache_chunks, dim=0), -+ self.attn_sink, -+ torch.cat(topk_chunks, dim=0), -+ self.softmax_scale, -+ ) -+ out_chunks = [] -+ for idx, freqs_cis in enumerate(freqs_chunks): -+ o_i = o[idx : idx 
+ 1] -+ _apply_rotary_emb(o_i[..., -rd:], freqs_cis, inverse=True) -+ out_chunks.append(o_i.squeeze(0)) -+ o = torch.cat(out_chunks, dim=0).view(total_tokens, self.n_local_groups, -1) -+ wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) -+ o = torch.einsum("sgd,grd->sgr", o, wo_a) -+ return self.wo_b(o.flatten(1)) -+ + outputs = [] + for start, end, start_pos, cache_slot in seq_meta: + seqlen = end - start diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 0e738d734..b23b04d3d 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2046,3 +2046,10 @@ - "Optimize DSv4 ATOM CONC=4 decode path by batching one-token sparse-attention calls across active requests inside the local ROCm/ATOM#650 overlay" - "Add a vectorized B>1,M=1 sparse_attn_v4 torch fallback so steady-state decode avoids one sparse-attention launch sequence per request per layer" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Revert the experimental DSv4 ATOM batched-decode sparse-attention overlay after it hung MI355X profile/sweep jobs" + - "Keep profile workflow support, but force PROFILE_NUM_STEPS=1 and PROFILE_OUTPUT_LEN=1 so profiler runs capture a one-token window" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX diff --git a/utils/bench_serving/benchmark_serving.py b/utils/bench_serving/benchmark_serving.py index 68887c59b..9ed7d40e8 100644 --- a/utils/bench_serving/benchmark_serving.py +++ b/utils/bench_serving/benchmark_serving.py @@ -532,13 +532,14 @@ async def warmup_limited_req_fn(): if profile: print("Starting profiler...") + profile_num_steps = int(os.environ.get("PROFILE_NUM_STEPS", "1")) profile_input = RequestFuncInput(model=model_id, model_name=model_name, prompt=test_prompt, api_url=base_url + "/start_profile", prompt_len=test_prompt_len, output_len=test_output_len, - extra_body={"num_steps": 1, "merge_profiles": True, "profile_by_stage": True}, + extra_body={"num_steps": profile_num_steps, "merge_profiles": True, "profile_by_stage": True}, logprobs=logprobs, best_of=best_of, multi_modal_content=test_mm_content, From c3229d342dba3c95d83c86597b24c36e23d4ae42 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 15:54:37 -0700 Subject: [PATCH 07/14] Route DSv4 ATOM sparse attention and Indexer through AITER Triton kernels --- .../single_node/dsv4_fp4_mi355x_atom.sh | 221 ++++++++++++++++-- perf-changelog.yaml | 6 +- 2 files changed, 210 insertions(+), 17 deletions(-) diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index 2d420853d..d1ae37e08 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -60,6 +60,8 @@ export AITER_LOG_LEVEL=WARNING # * sunway513/aiter@e450e4d adds DSv4 FP4 MoE tuned rows that route # eligible token counts to FlyDSL FP4 MoE kernels instead of default CK # heuristics when the image has the optional flydsl package. +# * Oseltamivir/aiter@083a837 adds DSv4 sparse MQA sink and Indexer +# scorer/top-k Triton ops so ATOM can avoid the PR650 Torch fallback. # # The open performance PRs cherry-pick cleanly over the pinned main SHA as # of 2026-04-29.
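The overlay hunks below follow the same pin-and-verify pattern used throughout this script: shallow-fetch a ref, then assert the fetched head matches the pinned SHA so a force-pushed branch cannot silently change what the benchmark runs. A compact Python rendering of the pattern (the repo, ref, and SHA arguments are illustrative):

    import subprocess

    def fetch_pinned(repo_dir: str, remote: str, ref: str, want_sha: str) -> None:
        # Shallow-fetch the ref, then check FETCH_HEAD against the pin.
        subprocess.run(["git", "fetch", "--depth=1", remote, ref],
                       cwd=repo_dir, check=True)
        got = subprocess.run(
            ["git", "rev-parse", "FETCH_HEAD"],
            cwd=repo_dir, check=True, capture_output=True, text=True,
        ).stdout.strip()
        if got != want_sha:
            raise RuntimeError(f"pinned overlay moved: {got} != {want_sha}")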
@@ -78,6 +80,10 @@ if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then AITER_DSV4_TUNED_FMOE_REPO=${AITER_DSV4_TUNED_FMOE_REPO:-https://github.com/sunway513/aiter.git} AITER_DSV4_TUNED_FMOE_SHA=${AITER_DSV4_TUNED_FMOE_SHA:-e450e4deb992c5ecd9db5ef5ef79f1d40208bc9c} AITER_DSV4_TUNED_FMOE_PATH=${AITER_DSV4_TUNED_FMOE_PATH:-aiter/configs/model_configs/dsv4_fp4_tuned_fmoe.csv} + AITER_DSV4_SPARSE_INDEXER=${AITER_DSV4_SPARSE_INDEXER:-1} + AITER_DSV4_SPARSE_INDEXER_REPO=${AITER_DSV4_SPARSE_INDEXER_REPO:-https://github.com/Oseltamivir/aiter.git} + AITER_DSV4_SPARSE_INDEXER_REF=${AITER_DSV4_SPARSE_INDEXER_REF:-dsv4-sparse-indexer} + AITER_DSV4_SPARSE_INDEXER_SHA=${AITER_DSV4_SPARSE_INDEXER_SHA:-083a837de5c44080b18b18682f2e7f611717a06b} rm -rf "$AITER_PERF_DIR" git clone --filter=blob:none "$AITER_PERF_REPO" "$AITER_PERF_DIR" @@ -104,6 +110,41 @@ if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then || { echo "FATAL: DSv4 FP4 tuned fMoE rows not found in $AITER_DSV4_TUNED_FMOE_PATH"; exit 1; } fi + if [ "$AITER_DSV4_SPARSE_INDEXER" = "1" ]; then + git fetch --depth=1 "$AITER_DSV4_SPARSE_INDEXER_REPO" "$AITER_DSV4_SPARSE_INDEXER_REF" + test "$(git rev-parse FETCH_HEAD)" = "$AITER_DSV4_SPARSE_INDEXER_SHA" + for file_path in \ + aiter/ops/triton/_triton_kernels/attention/dsv4_indexer.py \ + aiter/ops/triton/_triton_kernels/attention/sparse_mqa_sink.py \ + aiter/ops/triton/attention/dsv4_indexer.py \ + aiter/ops/triton/attention/sparse_mqa_sink.py \ + op_tests/test_dsv4_indexer.py \ + op_tests/test_sparse_mqa_sink.py + do + mkdir -p "$(dirname "$file_path")" + git show "FETCH_HEAD:$file_path" > "$file_path" + done + python3 - <<'PYEOF' +from pathlib import Path + +path = Path("aiter/ops/triton/__init__.py") +source = path.read_text() +needle = ' "prefill_attention": "attention.prefill_attention",\n' +insert = ' "dsv4_indexer": "attention.dsv4_indexer",\n "sparse_mqa_sink": "attention.sparse_mqa_sink",\n' +if '"dsv4_indexer": "attention.dsv4_indexer"' not in source: + if needle not in source: + raise SystemExit("FATAL: aiter triton __init__.py missing attention map anchor") + source = source.replace(needle, needle + insert, 1) +elif '"sparse_mqa_sink": "attention.sparse_mqa_sink"' not in source: + source = source.replace( + ' "dsv4_indexer": "attention.dsv4_indexer",\n', + insert, + 1, + ) +path.write_text(source) +PYEOF + fi + if [ ! -d 3rdparty/composable_kernel/include ]; then git submodule update --init --recursive --depth=1 3rdparty/composable_kernel \ || git submodule update --init --recursive 3rdparty/composable_kernel @@ -142,6 +183,12 @@ required = { ).exists(), "MXFP4 scaleN_pad fix": "scaleN_pad" in fp4_utils, "DSv4 FP4 tuned fMoE config": dsv4_tuned_fmoe is None or dsv4_tuned_fmoe.exists(), + "DSv4 sparse MQA sink Triton op": ( + root / "ops" / "triton" / "attention" / "sparse_mqa_sink.py" + ).exists(), + "DSv4 Indexer Triton op": ( + root / "ops" / "triton" / "attention" / "dsv4_indexer.py" + ).exists(), } missing = [name for name, ok in required.items() if not ok] if missing: @@ -181,22 +228,25 @@ else echo "WARN: AITER_DSV4_PERF_STACK=0; using image-provided aiter" fi -# Apply ROCm/ATOM#650 (DSv4 PR1 skeleton) over the image's wheel-installed -# atom. The chosen base image ships atom as a built wheel, not editable, so -# we overlay an editable install from the PR branch at a pinned SHA. Bump -# this SHA when the PR moves; do not track the branch tip (the run becomes -# a moving target if the branch is force-pushed). 
-ATOM_PR_SHA="af17eb89ceb6370b0c1724aef3bf938e6baedecd" +# Apply an ATOM DSv4 overlay over the image's wheel-installed atom. The default +# fork commit is ROCm/ATOM#650 at af17eb8 plus the local multi-request cache-slot +# fix and AITER sparse_attn/Indexer dispatch below. Keep ATOM_PR_* overridable so +# this can be pointed back at ROCm/ATOM#650 while debugging upstream movement. +ATOM_PR_REPO=${ATOM_PR_REPO:-https://github.com/Oseltamivir/ATOM.git} +ATOM_PR_REF=${ATOM_PR_REF:-dsv4-aiter-sparse-indexer} +ATOM_PR_SHA=${ATOM_PR_SHA:-0ddf9a9a7919631a9a89073d624bd25b16014f17} export ATOM_PR_DIR="/tmp/atom-pr650" if [ ! -d "$ATOM_PR_DIR/.git" ]; then - git clone --filter=blob:none https://github.com/ROCm/ATOM.git "$ATOM_PR_DIR" + git clone --filter=blob:none "$ATOM_PR_REPO" "$ATOM_PR_DIR" fi ( cd "$ATOM_PR_DIR" + git remote set-url origin "$ATOM_PR_REPO" # Try a targeted fetch first (fast); fall back to fetching the PR ref if # the server doesn't allow fetching the SHA directly. git fetch --depth=1 origin "$ATOM_PR_SHA" 2>/dev/null \ + || git fetch --depth=1 origin "$ATOM_PR_REF" 2>/dev/null \ || git fetch --depth=1 origin pull/650/head git checkout --force "$ATOM_PR_SHA" test "$(git rev-parse HEAD)" = "$ATOM_PR_SHA" @@ -208,12 +258,8 @@ fi || { echo "FATAL: ATOM DSv4 mhc_pre aiter hook not found"; exit 1; } # ROCm/ATOM#650 sparse_attn_v4.py is a correctness-first torch fallback. - # Add two local mitigations while we wait for a serving-compatible AITER - # sparse-attention kernel: - # 1. chunk prefill over the M dimension to keep temporary scores under - # memory pressure, making higher-conc experiments less likely to OOM; - # 2. use a B=1,M=1 decode fast path that avoids the fallback's large - # broadcast/mask/concat intermediates on every generated token. + # Route DSv4 sparse MQA through the forked AITER Triton kernel first. Keep + # the old chunk/decode mitigations only as an explicit fallback path. 
python3 - <<'PYEOF' from pathlib import Path @@ -234,6 +280,53 @@ if marker not in source: new = """ out_dtype = q.dtype device = q.device + if os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN", "1") == "1" and q.is_cuda: + try: + from aiter.ops.triton.attention.sparse_mqa_sink import sparse_mqa_sink + + block_size = int( + os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_SIZE", "128") + or "128" + ) + q_flat = q.reshape(B * M, H, D).contiguous() + topk_flat = topk_idxs.reshape(B * M, K).contiguous().int() + num_blocks = (N + block_size - 1) // block_size + padded_n = num_blocks * block_size + if padded_n != N: + kv_padded = kv.new_zeros((B, padded_n, D)) + kv_padded[:, :N] = kv + else: + kv_padded = kv.contiguous() + kv_blocks = ( + kv_padded.view(B, num_blocks, block_size, D) + .reshape(B * num_blocks, block_size, D) + .contiguous() + ) + block_table = torch.arange( + B * num_blocks, device=device, dtype=torch.int32 + ).view(B, num_blocks) + cu_seqlens_q = torch.arange( + 0, (B + 1) * M, M, device=device, dtype=torch.int32 + ) + seqused_k = torch.full((B,), N, device=device, dtype=torch.int32) + out = torch.empty_like(q_flat) + sparse_mqa_sink( + q_flat, + kv_blocks, + out, + cu_seqlens_q, + seqused_k, + float(softmax_scale), + topk_flat, + block_table, + attn_sink.float().contiguous(), + ) + return out.view(B, M, H, D).to(out_dtype) + except Exception as exc: + if os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN_STRICT", "1") == "1": + raise + print(f"WARN: AITER DSv4 sparse_attn failed, falling back to Torch: {exc!r}") + chunk_tokens = int(os.environ.get("ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS", "0") or "0") if B == 1 and chunk_tokens > 0 and M > chunk_tokens: return torch.cat( @@ -274,9 +367,9 @@ if marker not in source: raise SystemExit("FATAL: sparse_attn_v4.py did not match expected PR650 source") source = source.replace(old, new, 1) path.write_text(source) - print(f"applied DSv4 sparse_attn_v4 decode/chunk patch: {path}") + print(f"applied DSv4 sparse_attn_v4 AITER/decode/chunk patch: {path}") else: - print(f"DSv4 sparse_attn_v4 decode/chunk patch already present: {path}") + print(f"DSv4 sparse_attn_v4 AITER/decode/chunk patch already present: {path}") PYEOF # Local multi-request overlay for ROCm/ATOM#650. ATOM's scheduler passes @@ -287,6 +380,7 @@ PYEOF # MoE/FFN layer-by-layer. This fixes correctness for CONC>1 and avoids the # worst all-layers-per-request loop until upstream vectorizes the DSv4 # sparse-attention/cache path. + if ! grep -q 'def forward_batched' atom/models/deepseek_v4.py; then sed 's/^$/ /' <<'PATCH' | git apply --recount diff --git a/atom/model_engine/llm_engine.py b/atom/model_engine/llm_engine.py index 8de9532..ddde446 100644 @@ -879,6 +973,101 @@ index 46cf1b0..0d84c78 100644 def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: # In V4, the LM head is fused into DeepseekV4Model.forward (it consumes PATCH + else + echo "ATOM DSv4 multi-request overlay already present: atom/models/deepseek_v4.py" + fi + + # Replace the Indexer's [S,H,T] Torch scorer/topk with the forked AITER + # scorer. For 1k1k, the op takes its dense causal fast path and skips + # scoring entirely because every committed compressed entry is selected. 
+ python3 - <<'PYEOF' +from pathlib import Path + +path = Path("atom/models/deepseek_v4.py") +source = path.read_text() +marker = "ATOM_DSV4_AITER_INDEXER" +if marker not in source: + old = """ # ----- Index score ----- + index_score = torch.einsum( + "bshd,btd->bsht", q, self.kv_cache[slot, : end_pos // ratio] + ) + index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2) + + # ----- Top-k selection over compressed positions ----- + if start_pos == 0: + mask = ( + torch.arange(seqlen // ratio, device=x.device).repeat(seqlen, 1) + >= torch.arange(1, seqlen + 1, device=x.device).unsqueeze(1) // ratio + ) + index_score = index_score + torch.where(mask, float("-inf"), 0.0) + topk_idxs = index_score.topk(min(self.index_topk, end_pos // ratio), dim=-1)[1] + if start_pos == 0: + mask = ( + topk_idxs + >= torch.arange(1, seqlen + 1, device=x.device).unsqueeze(1) // ratio + ) + topk_idxs = torch.where(mask, -1, topk_idxs + offset) + else: + topk_idxs = topk_idxs + offset + return topk_idxs +""" + new = """ # ----- Index score / top-k selection over compressed positions ----- + n_committed = end_pos // ratio + if n_committed <= 0: + return torch.empty(1, seqlen, 0, dtype=torch.int32, device=x.device) + + if os.environ.get("ATOM_DSV4_AITER_INDEXER", "1") == "1" and q.is_cuda: + try: + from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk + + positions = torch.arange( + start_pos, end_pos, device=x.device, dtype=torch.int64 + ) + topk_idxs = dsv4_indexer_topk( + q.squeeze(0), + self.kv_cache[slot, :n_committed].squeeze(0), + weights.squeeze(0), + positions, + self.index_topk, + offset, + ratio=ratio, + ) + return topk_idxs.unsqueeze(0) + except Exception as exc: + if os.environ.get("ATOM_DSV4_AITER_INDEXER_STRICT", "1") == "1": + raise + print(f"WARN: AITER DSv4 Indexer failed, falling back to Torch: {exc!r}") + + index_score = torch.einsum( + "bshd,btd->bsht", q, self.kv_cache[slot, :n_committed] + ) + index_score = (index_score.relu_() * weights.unsqueeze(-1)).sum(dim=2) + + if start_pos == 0: + mask = ( + torch.arange(seqlen // ratio, device=x.device).repeat(seqlen, 1) + >= torch.arange(1, seqlen + 1, device=x.device).unsqueeze(1) // ratio + ) + index_score = index_score + torch.where(mask, float("-inf"), 0.0) + topk_idxs = index_score.topk(min(self.index_topk, n_committed), dim=-1)[1] + if start_pos == 0: + mask = ( + topk_idxs + >= torch.arange(1, seqlen + 1, device=x.device).unsqueeze(1) // ratio + ) + topk_idxs = torch.where(mask, -1, topk_idxs + offset) + else: + topk_idxs = topk_idxs + offset + return topk_idxs +""" + if old not in source: + raise SystemExit("FATAL: deepseek_v4.py did not match expected Indexer fallback") + source = source.replace(old, new, 1) + path.write_text(source) + print(f"applied DSv4 AITER Indexer patch: {path}") +else: + print(f"DSv4 AITER Indexer patch already present: {path}") +PYEOF # --no-deps: don't churn the image's pinned ROCm/torch/triton/aiter. # --force-reinstall: replace the wheel-installed atom with the editable copy. @@ -1008,6 +1197,8 @@ start_gpu_monitor set -x BLOCK_SIZE=${BLOCK_SIZE:-16} +export ATOM_DSV4_AITER_SPARSE_ATTN=${ATOM_DSV4_AITER_SPARSE_ATTN:-1} +export ATOM_DSV4_AITER_INDEXER=${ATOM_DSV4_AITER_INDEXER:-1} export ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS=${ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS:-256} # --enforce-eager is required: ROCm/ATOM#650 (PR1 skeleton) has no CUDAGraph # support yet (deferred to a follow-up PR). 
max-num-seqs is sized to the diff --git a/perf-changelog.yaml b/perf-changelog.yaml index b23b04d3d..f86a14dd5 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2050,6 +2050,8 @@ - config-keys: - dsv4-fp4-mi355x-atom description: - - "Revert the experimental DSv4 ATOM batched-decode sparse-attention overlay after it hung MI355X profile/sweep jobs" - - "Keep profile workflow support, but force PROFILE_NUM_STEPS=1 and PROFILE_OUTPUT_LEN=1 so profiler runs capture a one-token window" + - "Overlay Oseltamivir/aiter@083a837 DSv4 sparse MQA sink and Indexer scorer/top-k Triton kernels on top of the existing AITER FP4 perf stack" + - "Use Oseltamivir/ATOM@0ddf9a9 as the default ATOM overlay, with the local ROCm/ATOM#650 patch path retained for upstream debugging" + - "Patch the DSv4 sparse_attn_v4 and Indexer paths to use the AITER kernels before falling back to the correctness-first Torch implementations" + - "Use a dense causal Indexer fast path for 1k1k when all committed compressed entries are selected, avoiding the [S,H,T] Torch scoring tensor entirely" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX From 59ff44e2b1c8ca0f961d066a8b5f89a8d4a307ee Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 18:05:24 -0700 Subject: [PATCH 08/14] Add DSv4 ATOM TP4 comparison --- .github/configs/amd-master.yaml | 2 ++ perf-changelog.yaml | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index eaf7d36fa..778e9df27 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1555,8 +1555,10 @@ dsv4-fp4-mi355x-atom: - isl: 1024 osl: 1024 search-space: + - { tp: 4, ep: 1, conc-start: 1, conc-end: 4 } - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } - isl: 8192 osl: 1024 search-space: + - { tp: 4, ep: 1, conc-start: 1, conc-end: 4 } - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } diff --git a/perf-changelog.yaml b/perf-changelog.yaml index bbbdae5c3..d77a3cb09 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2046,3 +2046,10 @@ - "Patch the DSv4 sparse_attn_v4 and Indexer paths to use the AITER kernels before falling back to the correctness-first Torch implementations" - "Use a dense causal Indexer fast path for 1k1k when all committed compressed entries are selected, avoiding the [S,H,T] Torch scoring tensor entirely" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1229 + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Add TP=4 DSv4 ATOM sweep points at conc=1/2/4 for 1k1k and 8k1k" + - "Profile data shows TP=8 DSv4 ATOM decode/prefill is dominated by aiter cross_device_reduce kernels (~49% of GPU kernel time), so TP=4 tests whether fewer ranks reduce the communication bottleneck enough to improve per-GPU throughput" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX From 357996180e2495a7928058b6315e59c3f8236f59 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 18:11:06 -0700 Subject: [PATCH 09/14] Rebase DSv4 ATOM overlay on PR650 head --- benchmarks/single_node/dsv4_fp4_mi355x_atom.sh | 15 ++++++++------- perf-changelog.yaml | 8 ++++++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index d1ae37e08..ec9425871 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -229,12 +229,12 @@ else fi # Apply an ATOM DSv4 overlay over the 
image's wheel-installed atom. The default -# fork commit is ROCm/ATOM#650 at af17eb8 plus the local multi-request cache-slot -# fix and AITER sparse_attn/Indexer dispatch below. Keep ATOM_PR_* overridable so -# this can be pointed back at ROCm/ATOM#650 while debugging upstream movement. +# fork commit is ROCm/ATOM#650 head plus AITER sparse_attn/Indexer dispatch. +# Keep ATOM_PR_* overridable so this can be pointed back at ROCm/ATOM#650 while +# debugging upstream movement. ATOM_PR_REPO=${ATOM_PR_REPO:-https://github.com/Oseltamivir/ATOM.git} -ATOM_PR_REF=${ATOM_PR_REF:-dsv4-aiter-sparse-indexer} -ATOM_PR_SHA=${ATOM_PR_SHA:-0ddf9a9a7919631a9a89073d624bd25b16014f17} +ATOM_PR_REF=${ATOM_PR_REF:-dsv4-pr650-head-aiter-sparse} +ATOM_PR_SHA=${ATOM_PR_SHA:-d1a78e61af1a99fc2a156b40d45d011ccb648b5c} export ATOM_PR_DIR="/tmp/atom-pr650" if [ ! -d "$ATOM_PR_DIR/.git" ]; then @@ -265,7 +265,7 @@ from pathlib import Path path = Path("atom/model_ops/sparse_attn_v4.py") source = path.read_text() -marker = "ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS" +marker = "ATOM_DSV4_AITER_SPARSE_ATTN" if marker not in source: source = source.replace( "from typing import Tuple\n\nimport torch\n", @@ -380,7 +380,8 @@ PYEOF # MoE/FFN layer-by-layer. This fixes correctness for CONC>1 and avoids the # worst all-layers-per-request loop until upstream vectorizes the DSv4 # sparse-attention/cache path. - if ! grep -q 'def forward_batched' atom/models/deepseek_v4.py; then + if ! grep -q 'def forward_batched' atom/models/deepseek_v4.py \ + && ! grep -q '_v4_get_seq_metadata' atom/models/deepseek_v4.py; then sed 's/^$/ /' <<'PATCH' | git apply --recount diff --git a/atom/model_engine/llm_engine.py b/atom/model_engine/llm_engine.py index 8de9532..ddde446 100644 diff --git a/perf-changelog.yaml b/perf-changelog.yaml index d77a3cb09..c53422b3d 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2053,3 +2053,11 @@ - "Add TP=4 DSv4 ATOM sweep points at conc=1/2/4 for 1k1k and 8k1k" - "Profile data shows TP=8 DSv4 ATOM decode/prefill is dominated by aiter cross_device_reduce kernels (~49% of GPU kernel time), so TP=4 tests whether fewer ranks reduces the communication bottleneck enough to improve per-GPU throughput" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Bump the ATOM DSv4 overlay from Oseltamivir/ATOM@0ddf9a9 to Oseltamivir/ATOM@d1a78e6, rebased on current ROCm/ATOM#650 head a709564" + - "This picks up PR650's CPU metadata path, paged compressed KV plumbing, Triton SWA writes, and compressor state-write plumbing while keeping the AITER sparse_attn/Indexer dispatch" + - "Skip the older local multi-request patch when the newer PR650 _v4_get_seq_metadata path is present" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX From df0c152a13ba63e60654833c7c31f2fa41975de0 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 18:21:27 -0700 Subject: [PATCH 10/14] Retile DSv4 ATOM sparse attention --- .../single_node/dsv4_fp4_mi355x_atom.sh | 29 ++++++++++++++++--- perf-changelog.yaml | 8 +++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh index ec9425871..c8d2ec41e 100644 --- a/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh +++ b/benchmarks/single_node/dsv4_fp4_mi355x_atom.sh @@ -60,8 +60,9 @@ export AITER_LOG_LEVEL=WARNING # * sunway513/aiter@e450e4d adds DSv4 FP4 MoE tuned rows that route # eligible token counts to 
FlyDSL FP4 MoE kernels instead of default CK # heuristics when the image has the optional flydsl package. -# * Oseltamivir/aiter@083a837 adds DSv4 sparse MQA sink and Indexer -# scorer/top-k Triton ops so ATOM can avoid the PR650 Torch fallback. +# * Oseltamivir/aiter@023eb3b adds DSv4 sparse MQA sink and Indexer +# scorer/top-k Triton ops so ATOM can avoid the PR650 Torch fallback, plus +# a 4x128 sparse-attn tile that reduces repeated QK score work for D=512. # # The open performance PRs cherry-pick cleanly over the pinned main SHA as # of 2026-04-29. @@ -83,7 +84,7 @@ if [ "${AITER_DSV4_PERF_STACK:-1}" = "1" ]; then AITER_DSV4_SPARSE_INDEXER=${AITER_DSV4_SPARSE_INDEXER:-1} AITER_DSV4_SPARSE_INDEXER_REPO=${AITER_DSV4_SPARSE_INDEXER_REPO:-https://github.com/Oseltamivir/aiter.git} AITER_DSV4_SPARSE_INDEXER_REF=${AITER_DSV4_SPARSE_INDEXER_REF:-dsv4-sparse-indexer} - AITER_DSV4_SPARSE_INDEXER_SHA=${AITER_DSV4_SPARSE_INDEXER_SHA:-083a837de5c44080b18b18682f2e7f611717a06b} + AITER_DSV4_SPARSE_INDEXER_SHA=${AITER_DSV4_SPARSE_INDEXER_SHA:-023eb3bc190cd58646517a6c96a3b6b799bc1f40} rm -rf "$AITER_PERF_DIR" git clone --filter=blob:none "$AITER_PERF_REPO" "$AITER_PERF_DIR" @@ -234,7 +235,7 @@ fi # debugging upstream movement. ATOM_PR_REPO=${ATOM_PR_REPO:-https://github.com/Oseltamivir/ATOM.git} ATOM_PR_REF=${ATOM_PR_REF:-dsv4-pr650-head-aiter-sparse} -ATOM_PR_SHA=${ATOM_PR_SHA:-d1a78e61af1a99fc2a156b40d45d011ccb648b5c} +ATOM_PR_SHA=${ATOM_PR_SHA:-486d35fdeeb50c471329c2cd08681df9c3ad53ce} export ATOM_PR_DIR="/tmp/atom-pr650" if [ ! -d "$ATOM_PR_DIR/.git" ]; then @@ -288,6 +289,18 @@ if marker not in source: os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_SIZE", "128") or "128" ) + tile_k = int( + os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN_TILE_K", "64") or "64" + ) + block_h = int( + os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_H", "4") or "4" + ) + block_d = int( + os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_D", "128") or "128" + ) + score_d = int( + os.environ.get("ATOM_DSV4_AITER_SPARSE_ATTN_SCORE_D", "64") or "64" + ) q_flat = q.reshape(B * M, H, D).contiguous() topk_flat = topk_idxs.reshape(B * M, K).contiguous().int() num_blocks = (N + block_size - 1) // block_size @@ -320,6 +333,10 @@ if marker not in source: topk_flat, block_table, attn_sink.float().contiguous(), + tile_k=tile_k, + block_h=block_h, + block_d=block_d, + score_d=score_d, ) return out.view(B, M, H, D).to(out_dtype) except Exception as exc: @@ -1201,6 +1218,10 @@ BLOCK_SIZE=${BLOCK_SIZE:-16} export ATOM_DSV4_AITER_SPARSE_ATTN=${ATOM_DSV4_AITER_SPARSE_ATTN:-1} export ATOM_DSV4_AITER_INDEXER=${ATOM_DSV4_AITER_INDEXER:-1} export ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS=${ATOM_DSV4_SPARSE_ATTN_CHUNK_TOKENS:-256} +export ATOM_DSV4_AITER_SPARSE_ATTN_TILE_K=${ATOM_DSV4_AITER_SPARSE_ATTN_TILE_K:-64} +export ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_H=${ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_H:-4} +export ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_D=${ATOM_DSV4_AITER_SPARSE_ATTN_BLOCK_D:-128} +export ATOM_DSV4_AITER_SPARSE_ATTN_SCORE_D=${ATOM_DSV4_AITER_SPARSE_ATTN_SCORE_D:-64} # --enforce-eager is required: ROCm/ATOM#650 (PR1 skeleton) has no CUDAGraph # support yet (deferred to a follow-up PR). 
max-num-seqs is sized to the # client concurrency with a floor at 4 — the ATOM default (512) makes the diff --git a/perf-changelog.yaml b/perf-changelog.yaml index c53422b3d..cebc529e1 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -2061,3 +2061,11 @@ - "This picks up PR650's CPU metadata path, paged compressed KV plumbing, Triton SWA writes, and compressor state-write plumbing while keeping the AITER sparse_attn/Indexer dispatch" - "Skip the older local multi-request patch when the newer PR650 _v4_get_seq_metadata path is present" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX + +- config-keys: + - dsv4-fp4-mi355x-atom + description: + - "Retile the forked AITER DSv4 sparse MQA sink from 8x64 to 4x128 by default, keeping the same accumulator footprint while reducing repeated QK score work for 512-wide values" + - "Bump sparse/indexer overlay pins to Oseltamivir/aiter@023eb3b and Oseltamivir/ATOM@486d35f" + - "Expose ATOM_DSV4_AITER_SPARSE_ATTN_{TILE_K,BLOCK_H,BLOCK_D,SCORE_D} for profile-guided sparse attention tuning" + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/XXX From f00e5b7e8004ef1d15d9a788be882edaf4d7776a Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 22:19:54 -0700 Subject: [PATCH 11/14] eval --- .github/configs/amd-master.yaml | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 505ee8843..2bb005620 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1577,10 +1577,8 @@ dsv4-fp4-mi355x-atom: - isl: 1024 osl: 1024 search-space: - - { tp: 4, ep: 1, conc-start: 1, conc-end: 4 } - - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } - - isl: 8192 - osl: 1024 - search-space: - - { tp: 4, ep: 1, conc-start: 1, conc-end: 4 } - - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } + - { tp: 8, ep: 1, conc-start: 1, conc-end: 8 } +# - isl: 8192 +# osl: 1024 +# search-space: +# - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } From 4b06a002930a410485027de1eb1b1c1031e23eb4 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 22:28:08 -0700 Subject: [PATCH 12/14] Add DSv4 ATOM eval-only point --- .github/configs/amd-master.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 2bb005620..193a20576 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1578,7 +1578,7 @@ dsv4-fp4-mi355x-atom: osl: 1024 search-space: - { tp: 8, ep: 1, conc-start: 1, conc-end: 8 } -# - isl: 8192 -# osl: 1024 -# search-space: -# - { tp: 8, ep: 1, conc-start: 1, conc-end: 4 } + - isl: 8192 + osl: 1024 + search-space: + - { tp: 8, ep: 1, conc-start: 16, conc-end: 16 } From 5128a68b03e21767efb68b99636b06e1b34affef Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Thu, 30 Apr 2026 23:19:00 -0700 Subject: [PATCH 13/14] Fix DSv4 eval prompt encoding --- benchmarks/benchmark_lib.sh | 85 +++++++++++++++++++++++++++++++++---- 1 file changed, 77 insertions(+), 8 deletions(-) diff --git a/benchmarks/benchmark_lib.sh b/benchmarks/benchmark_lib.sh index 41fc891b3..6c707fc90 100644 --- a/benchmarks/benchmark_lib.sh +++ b/benchmarks/benchmark_lib.sh @@ -554,7 +554,7 @@ _patch_lm_eval() { patch_dir="$(mktemp -d)" cat > "$patch_dir/sitecustomize.py" <<'PY' # --- Patch LocalChatCompletion.parse_generations to handle empty content with reasoning_content --- -import re, sys, unicodedata, json +import os, re, sys, 
unicodedata, json from lm_eval.filters import extraction as ex from lm_eval.models.openai_completions import LocalChatCompletion as _LCC @@ -581,7 +581,7 @@ def _le_parse_generations(outputs, **kwargs): # Keep staticmethod semantics _LCC.parse_generations = staticmethod(_le_parse_generations) -# --- Patch TemplateAPI.apply_chat_template to avoid injecting "type": "text" for TRT --- +# --- Patch TemplateAPI.apply_chat_template --- try: from lm_eval.models import api_models as _api_models _TemplateAPI = _api_models.TemplateAPI @@ -592,6 +592,56 @@ except Exception: if _TemplateAPI is not None and _JsonChatStr is not None: _orig_apply_chat_template = _TemplateAPI.apply_chat_template + _dsv4_encode_messages = None + + def _content_to_text(content): + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + parts.append(str(item.get("text", item.get("content", "")))) + else: + parts.append(str(item)) + return "\n".join(part for part in parts if part) + if content is None: + return "" + return str(content) + + def _load_dsv4_encoder(): + global _dsv4_encode_messages + if _dsv4_encode_messages is not None: + return _dsv4_encode_messages + + roots = [ + os.environ.get("INFMAX_WORKSPACE"), + os.environ.get("GITHUB_WORKSPACE"), + os.getcwd(), + "/workspace", + "/infmax-workspace", + ] + for root in roots: + if not root: + continue + candidate = os.path.join(root, "utils", "bench_serving") + if os.path.exists(os.path.join(candidate, "encoding_dsv4.py")) and candidate not in sys.path: + sys.path.insert(0, candidate) + + from encoding_dsv4 import encode_messages + + _dsv4_encode_messages = encode_messages + return _dsv4_encode_messages + + def _apply_dsv4_chat_template(chat_history): + encode_messages = _load_dsv4_encoder() + messages = [] + for item in chat_history: + normalized = {**item} + normalized.pop("type", None) + normalized["content"] = _content_to_text(normalized.get("content")) + messages.append(normalized) + return encode_messages(messages, thinking_mode="thinking") def _patched_apply_chat_template( self, @@ -599,6 +649,8 @@ if _TemplateAPI is not None and _JsonChatStr is not None: add_generation_prompt: bool = True, ): """Applies a chat template to a list of chat history between user and model.""" + if os.environ.get("EVAL_DSV4_CHAT_TEMPLATE") == "1": + return _apply_dsv4_chat_template(chat_history) if self.tokenizer_backend == "huggingface" and self.tokenized_requests: return self.tokenizer.apply_chat_template( chat_history, @@ -703,13 +755,30 @@ run_lm_eval() { esac done - _install_lm_eval_deps - _patch_lm_eval - local openai_server_base="http://0.0.0.0:${port}" local openai_chat_base="${openai_server_base}/v1/chat/completions" + local openai_completions_base="${openai_server_base}/v1/completions" export OPENAI_API_KEY=${OPENAI_API_KEY:-EMPTY} - MODEL_NAME=${MODEL_NAME:-$MODEL} # Prefer MODEL_NAME, else MODEL + export MODEL_NAME="${MODEL_NAME:-$MODEL}" # Prefer MODEL_NAME, else MODEL + + local lm_eval_model="local-chat-completions" + local lm_eval_base_url="$openai_chat_base" + local lm_eval_eos_string="${EVAL_EOS_STRING:-}" + local lm_eval_tokenizer_args="tokenized_requests=False" + + if [[ "${MODEL_PREFIX:-}" == "dsv4" || "${MODEL_NAME:-}" == *"DeepSeek-V4"* || "${MODEL:-}" == *"DeepSeek-V4"* ]]; then + export EVAL_DSV4_CHAT_TEMPLATE=1 + lm_eval_model="local-completions" + lm_eval_base_url="$openai_completions_base" + lm_eval_eos_string="${EVAL_EOS_STRING:-<|end▁of▁sentence|>}" + 
lm_eval_tokenizer_args="tokenizer_backend=None,tokenized_requests=False" + echo "Using DeepSeek-V4 eval prompt encoding via utils/bench_serving/encoding_dsv4.py" + else + unset EVAL_DSV4_CHAT_TEMPLATE + fi + + _install_lm_eval_deps + _patch_lm_eval # Cap output tokens: must fit within context window (leave room for input), # and avoid excessive KV cache reservation per request on TRT. @@ -722,11 +791,11 @@ run_lm_eval() { # Export for append_lm_eval_summary to pick up export EVAL_RESULT_DIR="$results_dir" set -x - python3 -m lm_eval --model local-chat-completions --apply_chat_template \ + python3 -m lm_eval --model "${lm_eval_model}" --apply_chat_template \ --tasks "${tasks_dir}" \ --output_path "${results_dir}" \ --log_samples \ - --model_args "model=${MODEL_NAME},base_url=${openai_chat_base},api_key=${OPENAI_API_KEY},eos_string=,max_retries=5,num_concurrent=${concurrent_requests},timeout=1800,tokenized_requests=False,max_length=${eval_context_len}" \ + --model_args "model=${MODEL_NAME},base_url=${lm_eval_base_url},api_key=${OPENAI_API_KEY},eos_string=${lm_eval_eos_string},max_retries=5,num_concurrent=${concurrent_requests},timeout=1800,${lm_eval_tokenizer_args},max_length=${eval_context_len}" \ --gen_kwargs "max_tokens=${max_output_tokens},temperature=${temperature},top_p=${top_p}" local eval_exit=$? set +x From 8d49a42070d9109c452aa379ca73f345a0009a79 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Fri, 1 May 2026 00:24:38 -0700 Subject: [PATCH 14/14] Limit DSv4 ATOM eval smoke run --- .github/configs/amd-master.yaml | 2 +- benchmarks/benchmark_lib.sh | 11 ++++++++++- utils/matrix_logic/generate_sweep_configs.py | 2 +- utils/matrix_logic/test_generate_sweep_configs.py | 7 +++---- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 193a20576..7c1b4d99e 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1581,4 +1581,4 @@ dsv4-fp4-mi355x-atom: - isl: 8192 osl: 1024 search-space: - - { tp: 8, ep: 1, conc-start: 16, conc-end: 16 } + - { tp: 8, ep: 1, conc-start: 4, conc-end: 4 } diff --git a/benchmarks/benchmark_lib.sh b/benchmarks/benchmark_lib.sh index 6c707fc90..600b9aea9 100644 --- a/benchmarks/benchmark_lib.sh +++ b/benchmarks/benchmark_lib.sh @@ -741,7 +741,8 @@ run_lm_eval() { local eval_context_len="${EVAL_MAX_MODEL_LEN:-16384}" local temperature=0 local top_p=1 - local concurrent_requests="${EVAL_CONCURRENT_REQUESTS:-64}" + local concurrent_requests="${EVAL_CONCURRENT_REQUESTS:-${CONC:-64}}" + local eval_limit="${EVAL_LIMIT:-}" while [[ $# -gt 0 ]]; do case $1 in @@ -751,6 +752,7 @@ run_lm_eval() { --gen-max-tokens) eval_context_len="$2"; shift 2 ;; --temperature) temperature="$2"; shift 2 ;; --top-p) top_p="$2"; shift 2 ;; + --limit) eval_limit="$2"; shift 2 ;; *) echo "Unknown parameter: $1"; return 1 ;; esac done @@ -772,6 +774,7 @@ run_lm_eval() { lm_eval_base_url="$openai_completions_base" lm_eval_eos_string="${EVAL_EOS_STRING:-<|end▁of▁sentence|>}" lm_eval_tokenizer_args="tokenizer_backend=None,tokenized_requests=False" + eval_limit="${eval_limit:-40}" echo "Using DeepSeek-V4 eval prompt encoding via utils/bench_serving/encoding_dsv4.py" else unset EVAL_DSV4_CHAT_TEMPLATE @@ -790,11 +793,17 @@ run_lm_eval() { # Export for append_lm_eval_summary to pick up export EVAL_RESULT_DIR="$results_dir" + local limit_args=() + if [ -n "$eval_limit" ]; then + limit_args=(--limit "$eval_limit") + echo "Eval sample limit: ${eval_limit}" + fi set -x python3 -m lm_eval 
--model "${lm_eval_model}" --apply_chat_template \ --tasks "${tasks_dir}" \ --output_path "${results_dir}" \ --log_samples \ + "${limit_args[@]}" \ --model_args "model=${MODEL_NAME},base_url=${lm_eval_base_url},api_key=${OPENAI_API_KEY},eos_string=${lm_eval_eos_string},max_retries=5,num_concurrent=${concurrent_requests},timeout=1800,${lm_eval_tokenizer_args},max_length=${eval_context_len}" \ --gen_kwargs "max_tokens=${max_output_tokens},temperature=${temperature},top_p=${top_p}" local eval_exit=$? diff --git a/utils/matrix_logic/generate_sweep_configs.py b/utils/matrix_logic/generate_sweep_configs.py index aeebcfa1f..2a90c679a 100644 --- a/utils/matrix_logic/generate_sweep_configs.py +++ b/utils/matrix_logic/generate_sweep_configs.py @@ -19,7 +19,7 @@ "8k1k": (8192, 1024) } -MIN_EVAL_CONC = 16 +MIN_EVAL_CONC = 4 # Reverse mapping for exp-name generation seq_len_itos = {v: k for k, v in seq_len_stoi.items()} diff --git a/utils/matrix_logic/test_generate_sweep_configs.py b/utils/matrix_logic/test_generate_sweep_configs.py index 34bd4dc3d..fe4bd3033 100644 --- a/utils/matrix_logic/test_generate_sweep_configs.py +++ b/utils/matrix_logic/test_generate_sweep_configs.py @@ -305,7 +305,7 @@ def test_multi_node_eval_conc_uses_only_conc_values_at_or_above_min_conc(self): "ep": 1, "dp-attn": False, }, - "conc": [8, 16, 32], + "conc": [MIN_EVAL_CONC // 2, MIN_EVAL_CONC, MIN_EVAL_CONC * 2], }, { "model": "deepseek-ai/DeepSeek-R1-0528", @@ -327,14 +327,14 @@ def test_multi_node_eval_conc_uses_only_conc_values_at_or_above_min_conc(self): "ep": 1, "dp-attn": False, }, - "conc": [8], + "conc": [MIN_EVAL_CONC // 2], }, ] result = mark_eval_entries(matrix_values) assert result[0]["run-eval"] is True - assert result[0]["eval-conc"] == 32 + assert result[0]["eval-conc"] == MIN_EVAL_CONC * 2 assert result[1]["run-eval"] is False def test_marks_highest_and_median_conc(self): @@ -1928,4 +1928,3 @@ def test_prefill_entries_never_in_single_or_evals(self, mixed_entries): assert all('prefill' in x for x in multi) assert all('prefill' not in x for x in single) assert all('prefill' not in x for x in evals) -
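Restating the MIN_EVAL_CONC rule the tests above encode, as a toy: an entry is eval-eligible only if some swept concurrency reaches the minimum, and eval runs at the highest qualifying value (the median-marking variant covered by test_marks_highest_and_median_conc is not modeled here). The pick_eval_conc helper is hypothetical, standing in for the real mark_eval_entries logic:

    MIN_EVAL_CONC = 4

    def pick_eval_conc(conc_values):
        # Highest concurrency at or above the minimum, or None if nothing
        # qualifies (run-eval stays False for that entry).
        qualifying = [c for c in conc_values if c >= MIN_EVAL_CONC]
        return max(qualifying) if qualifying else None

    assert pick_eval_conc([MIN_EVAL_CONC // 2, MIN_EVAL_CONC, MIN_EVAL_CONC * 2]) == MIN_EVAL_CONC * 2
    assert pick_eval_conc([MIN_EVAL_CONC // 2]) is None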