feat: add Step-3.5-Flash support and fix MoE weight shuffling on gfx950#641
Draft
LJ-underdog wants to merge 8 commits intomainfrom
Draft
feat: add Step-3.5-Flash support and fix MoE weight shuffling on gfx950#641LJ-underdog wants to merge 8 commits intomainfrom
LJ-underdog wants to merge 8 commits intomainfrom
Conversation
Add Step3p5ForCausalLM model support for the Step-3.5-Flash architecture,
and fix a critical MoE correctness bug on gfx950 (MI350X).
Core MoE fix (atom/model_ops/moe.py):
Previously skipped shuffle_weights() for gfx950 BF16 g1u1 based on the
incorrect assumption that the CK 2-stage preshuffle_off (NSwizzle=0)
kernel expects un-shuffled weights. Verified: preshuffle_off GEMM is
wrong on gfx950; preshuffle_on (NSwizzle=1) is correct. Always call
shuffle_weights() so the correct kernel path is selected.
Step-3.5-Flash model support (atom/models/step3p5.py):
- Mixed full/sliding window attention (per layer_types config)
- 288 routed + 1 shared expert MoE with sigmoid routing
- Per-layer SwigluStep activation: layers with swiglu_limits[i]>0 use
ActivationType.SwigluStep (CK kernel applies silu(g).clamp(7)*up.clamp(±7));
other layers use plain Silu. Shared expert at SwigluStep layers is kept
on the dense MLP path (kernel clamp is routed-expert-only).
- Fused expert loading (flat [E,I,H] checkpoint format)
- clamp_limit applied to dense MLP and shared expert via Step3p5MLP
atom/model_engine/model_runner.py:
- Register Step3p5ForCausalLM architecture
- Handle num_attention_groups config key (Step-3.5 uses this instead of
num_key_value_heads) in KV head count calculations
atom/model_loader/loader.py:
- Fix fused expert detection order: check before packed_modules_mapping
to prevent moe.gate_proj being matched as gate_up_proj
atom/model_ops/attentions/aiter_attention.py:
- Handle num_attention_groups config key for KV head count
atom/examples/simple_inference.py:
- Add --max-tokens arg and trust_remote_code support
Verified: tp=2 Step-3.5-Flash inference, 4 prompts, no NaN/crash,
coherent output (with ATOM_STEP3P5_NO_SLIDING=1 workaround for
pa_decode_gluon bug on gfx950, tracked separately).
Co-Authored-By: Jun Lin <junlin12@amd.com>
| """ | ||
|
|
||
| import os | ||
| from typing import Any, Optional, Union |
Contributor
| print(f"[NAN-ATTN] attn_output.shape={attn_output.shape} reshaped.shape={reshaped.shape} gate.shape={gate.shape}") | ||
| attn_output = (reshaped * gate).flatten(-2) | ||
| if debug_nan and attn_output.isnan().any(): | ||
| print(f"[NAN-ATTN] after gate multiply has NaN") |
Contributor
|
|
||
| output = self.o_proj(attn_output) | ||
| if debug_nan and output.isnan().any(): | ||
| print(f"[NAN-ATTN] o_proj output has NaN") |
Contributor
| # Per-layer SwiGLU clamp limits | ||
| swiglu_limits = getattr(config, "swiglu_limits", None) | ||
| swiglu_limits_shared = getattr(config, "swiglu_limits_shared", None) | ||
| clamp_limit = swiglu_limits[layer_idx] if swiglu_limits else None |
Contributor
CK 2-stage MoE kernel (gemm_moe_ck2stages.cu L98) computes stage1 N as w1.size(1)/2 = inter_dim. The stage1 dispatch selects NPerBlock based on inter_dim range: - inter <= 192: NPerBlock = 64 -> need inter % 64 == 0 - inter > 192: NPerBlock = 128 -> need inter % 128 == 0 Step-3.5-Flash with tp=4 gives inter=320 (320%128=64 != 0, crash) and with tp=8 gives inter=160 (160%64=32 != 0, crash). Fix: in process_weights_after_loading, pad inter_dim before shuffle_weights() using alignment = 64 if inter<=192 else 128: - inter=160 -> 192 (tp=8, 192%64=0) - inter=320 -> 384 (tp=4, 384%128=0, 384%64=0) Zero-padding is safe: padded rows carry zero weight so contribute nothing to fused_moe output. Verified 2026-04-24 on gfx950 (MI350X): - cos_sim >= 0.9999 vs torch reference (M=1..256) - tp=4 inference: 4 prompts complete, no crash, output correct Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The else branch in get_fused_moe_quant_config was shared between block_quant (per_1x128/per_1x32) and per_tensor paths, hardcoding block_shape=None for all. Block-quantized FP8 models should receive block_shape=[128,128] (per_1x128) or [1,32] (per_1x32) to correctly configure the quant config, particularly for EP paths. Split the else branch into explicit per_1x128/per_1x32/fallback cases and unify the fp8_w8a8_moe_quant_config call.
d8caf2e to
841dc4e
Compare
Three coordinated fixes in Fp8MoEMethod for per_1x128 block scale: 1. create_weights: make ValueError check padding-aware Compute padded_inter = ceil(inter/block_n)*block_n and check against padded_inter instead of raw inter, allowing tp=4 (inter=320) to pass while preserving the guard for truly unaligned cases. 2. _process_block_quant: zero-pad weights before shuffle_weights After normalize and before shuffle, zero-pad w13 from [E,2*320,H] to [E,2*384,H] and w2 from [E,H,320] to [E,H,384], mirroring the BF16 approach in UnquantizedFusedMoEMethod.process_weights_after_loading. Padding zeros contribute 0 to GEMM output (dequant(0, scale)=0). Scale tensors already use ceil(inter/block_n) and need no change. 3. _load_w13 / _load_w2: fix scale TP sharding floor→ceil (root cause) The per_1x128 scale for full inter=1280 has 10 N-blocks. TP=4 sharding with floor gives 10//4=2 blocks per rank; the 3rd (partial) block is never copied and stays at the torch.ones() init value of 1.0. With scale=1.0 instead of ~0.0002, dequant amplifies by ~5000× causing complete garbage output despite correct weight loading. Fix: use ceil division and add narrow() bounds protection for the last rank which may have fewer elements than the ceil size. Safe for tp=2 (10/2=5 exact, ceil==floor) and tp=1 (no sharding). Verification: FP8 tp=4: 4 prompts, TTFT=92ms, TPOT=14ms, coherent output ✅ BF16 tp=4 regression: TTFT=76-77ms, coherent output ✅ FP8 tp=2 regression: TTFT=86ms, coherent output ✅
…ding With NPerBlock=64 CK kernel support, inter_dim=320 (tp=4) is 64-aligned and no longer requires zero-padding to 384. Changed align from '64 if inter<=192 else block_n' to always 64, so: - tp=4 (inter=320): 320%64=0 -> no padding (was 320->384, saved 17% compute) - tp=8 (inter=160): 160%64=32 -> pad to 192 (unchanged) - tp=2 (inter=640): 640%64=0 -> no padding (unchanged) Scale tensor shape (ceil(320/128)=3) unchanged; no re-quantization needed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Stage2 KPerBlock=64 is not compilable on gfx950 (FP8 mfma KPack=32 constraint). Since stage1 output and stage2 weight K must match, both w13 and w2 require the same inter_dim padding. Restoring: align = 64 if inter_dim <= 192 else block_n (=128) Added comment explaining why full no-padding is currently blocked. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…onfigs _process_block_quant used 'align = 64 if inter_dim <= 192 else block_n', copied from the BF16 path. For FP8 blockscale this is wrong: - FP8 stage2 only has KPerBlock=128 (KPack=32 mfma constraint prevents KPerBlock=64) - align=64 gives inter_pad=192 for tp=8 (inter=160), but 192 % 128 = 64 != 0 - device_moe_gemm_blockscale.hpp L448 rejects K % KPerBlock != 0 → kernel fails Fix: always use align = block_n (=128 for per_1x128), so inter_pad is always a multiple of 128 and stage2 KPerBlock=128 dispatch succeeds: tp=2: inter=640 → 640 (no padding, unchanged) tp=4: inter=320 → 384 (unchanged) tp=8: inter=160 → 256 (was 192, now correctly aligned) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When the per_1x128 scale block count is smaller than tp_size (observed on Step-3.5-Flash-FP8 at tp=8 with inter_dim=1280 → D=10), the ceil split leaves trailing ranks with start >= D so narrow(start, size) hits size<0 and crashes weight load. Skip narrow + copy_ for those ranks. For fp8 scale tensors (torch.ones() initialised in Fp8MoEMethod._create_weights), additionally zero the rank's slot before the early return. Otherwise the downstream fp8 dequant multiplies the (uninitialised) fp8 weight by stale 1.0 instead of the correct quantization scale, contaminating the column gather / row reduction and producing garbled output. Matches MXFP4 scale init (moe.py:776,813). Verified on stepfun-ai/Step-3.5-Flash-FP8 (gfx942 / MI308X): - tp=8 A1/A2/A4 PASS — 4/4 prompts coherent (was: weight-load crash pre-patch; was: garbled output with early-return-only) - tp=2/tp=4 A1/A2/A3 PASS — no regression, zero-trigger confirmed (D=10, starts=[0,3,6,9] for tp=4, starts=[0,5] for tp=2 — all < D) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add Step3p5ForCausalLM model support for the Step-3.5-Flash architecture, and fix a critical MoE correctness bug on gfx950 (MI350X).
Core MoE fix (atom/model_ops/moe.py):
Previously skipped shuffle_weights() for gfx950 BF16 g1u1 based on the
incorrect assumption that the CK 2-stage preshuffle_off (NSwizzle=0)
kernel expects un-shuffled weights. Verified: preshuffle_off GEMM is
wrong on gfx950; preshuffle_on (NSwizzle=1) is correct. Always call
shuffle_weights() so the correct kernel path is selected.
Step-3.5-Flash model support (atom/models/step3p5.py):
atom/model_engine/model_runner.py:
atom/model_loader/loader.py:
atom/model_ops/attentions/aiter_attention.py:
atom/examples/simple_inference.py:
Verified: tp=2 Step-3.5-Flash inference, 4 prompts, no NaN/crash, coherent output (with ATOM_STEP3P5_NO_SLIDING=1 workaround for pa_decode_gluon bug on gfx950, tracked separately).
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist