
feat: add Step-3.5-Flash support and fix MoE weight shuffling on gfx950 #641

Draft

LJ-underdog wants to merge 8 commits into main from feat/step3p5-flash-support

Conversation

@LJ-underdog

Add Step3p5ForCausalLM model support for the Step-3.5-Flash architecture, and fix a critical MoE correctness bug on gfx950 (MI350X).

Core MoE fix (atom/model_ops/moe.py):
Previously skipped shuffle_weights() for gfx950 BF16 g1u1 based on the
incorrect assumption that the CK 2-stage preshuffle_off (NSwizzle=0)
kernel expects un-shuffled weights. Verified: preshuffle_off GEMM is
wrong on gfx950; preshuffle_on (NSwizzle=1) is correct. Always call
shuffle_weights() so the correct kernel path is selected.
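A minimal sketch of the intended change in process_weights_after_loading (only shuffle_weights() is named in this PR; the attribute names and the exact shuffle_weights() signature below are illustrative assumptions):

    # Sketch only: attribute names and the shuffle_weights() signature are assumptions.
    def process_weights_after_loading(self, layer) -> None:
        # The old code skipped the shuffle on gfx950 for BF16 g1u1, which routed
        # execution to the CK 2-stage preshuffle_off (NSwizzle=0) kernel whose
        # GEMM output is wrong on gfx950. Always shuffling keeps the verified
        # preshuffle_on (NSwizzle=1) path.
        layer.w13_weight.data = shuffle_weights(layer.w13_weight.data)
        layer.w2_weight.data = shuffle_weights(layer.w2_weight.data)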

Step-3.5-Flash model support (atom/models/step3p5.py):

  • Mixed full/sliding window attention (per layer_types config)
  • 288 routed + 1 shared expert MoE with sigmoid routing
  • Per-layer SwigluStep activation: layers with swiglu_limits[i] > 0 use ActivationType.SwigluStep (the CK kernel applies silu(g).clamp(7) * up.clamp(±7)); other layers use plain Silu. The shared expert at SwigluStep layers stays on the dense MLP path, since the kernel clamp applies to routed experts only (see the reference sketch after this list).
  • Fused expert loading (flat [E,I,H] checkpoint format)
  • clamp_limit applied to dense MLP and shared expert via Step3p5MLP
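A reference-level sketch of the clamp semantics and the per-layer activation choice described above. It reads silu(g).clamp(7) as an upper clamp on the gated term and up.clamp(±7) as a symmetric clamp; this is an interpretation for illustration, not the CK kernel itself:

    import torch
    import torch.nn.functional as F

    def swiglu_step_reference(gate: torch.Tensor, up: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
        # Interpretation of the routed-expert kernel clamp: cap silu(gate) at
        # `limit` and clamp the up projection to [-limit, limit].
        return F.silu(gate).clamp(max=limit) * up.clamp(min=-limit, max=limit)

    # Hypothetical per-layer selection (names follow the description above):
    # limit = swiglu_limits[layer_idx] if swiglu_limits else 0
    # act = ActivationType.SwigluStep if limit > 0 else ActivationType.Silu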

atom/model_engine/model_runner.py:

  • Register Step3p5ForCausalLM architecture
  • Handle num_attention_groups config key (Step-3.5 uses this instead of num_key_value_heads) in KV head count calculations
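A sketch of the KV-head fallback (the two config keys are the ones named above; the surrounding code is illustrative):

    # Step-3.5 configs expose the KV group count as num_attention_groups rather
    # than num_key_value_heads, so fall back when computing KV head counts.
    num_kv_heads = getattr(config, "num_key_value_heads", None)
    if num_kv_heads is None:
        num_kv_heads = getattr(config, "num_attention_groups", config.num_attention_heads)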

atom/model_loader/loader.py:

  • Fix fused expert detection order: check before packed_modules_mapping to prevent moe.gate_proj being matched as gate_up_proj
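Illustrative ordering only; the helper below is hypothetical rather than the actual loader API, but it shows why the fused-expert check has to run before the packed-module mapping:

    def resolve_param_name(name: str, packed_modules_mapping: dict[str, list[str]]) -> str:
        # Fused expert weights (flat [E, I, H]) must be recognised first;
        # otherwise "moe.gate_proj" is rewritten to "gate_up_proj" below and
        # the fused-expert loading path never sees it.
        if ".experts." in name:
            return name  # handled by the fused-expert loading path
        for packed_name, shard_names in packed_modules_mapping.items():
            for shard in shard_names:
                if shard in name:
                    return name.replace(shard, packed_name)
        return name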

atom/model_ops/attentions/aiter_attention.py:

  • Handle num_attention_groups config key for KV head count

atom/examples/simple_inference.py:

  • Add --max-tokens arg and trust_remote_code support
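Rough shape of the example-script additions (flag names follow the bullet above; the default value and whether trust_remote_code is a flag or always-on are assumptions):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--max-tokens", type=int, default=256,
                        help="Maximum number of new tokens to generate per prompt.")
    parser.add_argument("--trust-remote-code", action="store_true",
                        help="Pass trust_remote_code=True when loading the model/tokenizer.")
    args = parser.parse_args()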

Verified: tp=2 Step-3.5-Flash inference, 4 prompts, no NaN/crash, coherent output (with ATOM_STEP3P5_NO_SLIDING=1 workaround for pa_decode_gluon bug on gfx950, tracked separately).


Co-Authored-By: Jun Lin <junlin12@amd.com>
Comment thread atom/models/step3p5.py
"""

import os
from typing import Any, Optional, Union


⚠️ [ruff] <F401> reported by reviewdog 🐶
typing.Any imported but unused

Suggested change
from typing import Any, Optional, Union
from typing import Optional, Union

Comment thread atom/models/step3p5.py
print(f"[NAN-ATTN] attn_output.shape={attn_output.shape} reshaped.shape={reshaped.shape} gate.shape={gate.shape}")
attn_output = (reshaped * gate).flatten(-2)
if debug_nan and attn_output.isnan().any():
print(f"[NAN-ATTN] after gate multiply has NaN")


⚠️ [ruff] <F541> reported by reviewdog 🐶
f-string without any placeholders

Suggested change
print(f"[NAN-ATTN] after gate multiply has NaN")
print("[NAN-ATTN] after gate multiply has NaN")

Comment thread atom/models/step3p5.py

output = self.o_proj(attn_output)
if debug_nan and output.isnan().any():
print(f"[NAN-ATTN] o_proj output has NaN")


⚠️ [ruff] <F541> reported by reviewdog 🐶
f-string without any placeholders

Suggested change
print(f"[NAN-ATTN] o_proj output has NaN")
print("[NAN-ATTN] o_proj output has NaN")

Comment thread atom/models/step3p5.py
# Per-layer SwiGLU clamp limits
swiglu_limits = getattr(config, "swiglu_limits", None)
swiglu_limits_shared = getattr(config, "swiglu_limits_shared", None)
clamp_limit = swiglu_limits[layer_idx] if swiglu_limits else None


⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable clamp_limit is assigned to but never used

Suggested change
clamp_limit = swiglu_limits[layer_idx] if swiglu_limits else None
swiglu_limits[layer_idx] if swiglu_limits else None

LJ-underdog and others added 2 commits April 24, 2026 21:54
CK 2-stage MoE kernel (gemm_moe_ck2stages.cu L98) computes stage1 N
as w1.size(1)/2 = inter_dim. The stage1 dispatch selects NPerBlock
based on inter_dim range:
  - inter <= 192: NPerBlock = 64  -> need inter % 64 == 0
  - inter >  192: NPerBlock = 128 -> need inter % 128 == 0

Step-3.5-Flash with tp=4 gives inter=320 (320%128=64 != 0, crash)
and with tp=8 gives inter=160 (160%64=32 != 0, crash).

Fix: in process_weights_after_loading, pad inter_dim before
shuffle_weights() using alignment = 64 if inter<=192 else 128:
  - inter=160 -> 192  (tp=8, 192%64=0)
  - inter=320 -> 384  (tp=4, 384%128=0, 384%64=0)

Zero-padding is safe: padded rows carry zero weight so contribute
nothing to fused_moe output.
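A sketch of that padding step, assuming w13 stacks the gate and up halves along dim 1 as [E, 2*inter_dim, H] and w2 is [E, H, inter_dim] (layouts from this commit text; the helper itself is illustrative):

    import torch
    import torch.nn.functional as F

    def pad_inter_dim(w13: torch.Tensor, w2: torch.Tensor, inter_dim: int):
        align = 64 if inter_dim <= 192 else 128         # mirror the CK stage1 dispatch
        inter_pad = -(-inter_dim // align) * align      # round up to a multiple of align
        pad = inter_pad - inter_dim
        if pad == 0:
            return w13, w2
        gate, up = w13.chunk(2, dim=1)                  # pad each half's inter rows
        w13 = torch.cat([F.pad(gate, (0, 0, 0, pad)), F.pad(up, (0, 0, 0, pad))], dim=1)
        w2 = F.pad(w2, (0, pad))                        # pad the trailing inter dim
        return w13, w2

Padded rows and columns are zero, so they contribute nothing to either GEMM stage, matching the safety argument above.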

Verified 2026-04-24 on gfx950 (MI350X):
  - cos_sim >= 0.9999 vs torch reference (M=1..256)
  - tp=4 inference: 4 prompts complete, no crash, output correct

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The else branch in get_fused_moe_quant_config was shared between
block_quant (per_1x128/per_1x32) and per_tensor paths, hardcoding
block_shape=None for all. Block-quantized FP8 models should receive
block_shape=[128,128] (per_1x128) or [1,32] (per_1x32) to correctly
configure the quant config, particularly for EP paths.

Split the else branch into explicit per_1x128/per_1x32/fallback cases
and unify the fp8_w8a8_moe_quant_config call.
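Illustrative shape of the split; only fp8_w8a8_moe_quant_config is named in this commit, and the helper below is a hypothetical restructuring:

    def _block_shape_for(quant_method: str | None) -> list[int] | None:
        # per_1x128 / per_1x32 block-quantized FP8 get explicit block shapes;
        # plain per-tensor FP8 keeps block_shape=None.
        if quant_method == "per_1x128":
            return [128, 128]
        if quant_method == "per_1x32":
            return [1, 32]
        return None

Each branch then ends in the same fp8_w8a8_moe_quant_config(...) call, differing only in the block_shape it passes.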
@LJ-underdog force-pushed the feat/step3p5-flash-support branch from d8caf2e to 841dc4e on April 24, 2026 at 21:54
LJ-underdog and others added 5 commits April 25, 2026 01:33
Three coordinated fixes in Fp8MoEMethod for per_1x128 block scale:

1. create_weights: make ValueError check padding-aware
   Compute padded_inter = ceil(inter/block_n)*block_n and check against
   padded_inter instead of raw inter, allowing tp=4 (inter=320) to pass
   while preserving the guard for truly unaligned cases.

2. _process_block_quant: zero-pad weights before shuffle_weights
   After normalize and before shuffle, zero-pad w13 from [E,2*320,H] to
   [E,2*384,H] and w2 from [E,H,320] to [E,H,384], mirroring the BF16
   approach in UnquantizedFusedMoEMethod.process_weights_after_loading.
   Padding zeros contribute 0 to GEMM output (dequant(0, scale)=0).
   Scale tensors already use ceil(inter/block_n) and need no change.

3. _load_w13 / _load_w2: fix scale TP sharding floor→ceil (root cause)
   The per_1x128 scale for full inter=1280 has 10 N-blocks. TP=4 sharding
   with floor gives 10//4=2 blocks per rank; the 3rd (partial) block is
   never copied and stays at the torch.ones() init value of 1.0.
   With scale=1.0 instead of ~0.0002, dequant amplifies by ~5000×
   causing complete garbage output despite correct weight loading.
   Fix: use ceil division and add narrow() bounds protection for the
   last rank which may have fewer elements than the ceil size.
   Safe for tp=2 (10/2=5 exact, ceil==floor) and tp=1 (no sharding).
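A sketch of the ceil sharding with the bounds guard (names are illustrative; D is the per_1x128 N-block count, e.g. ceil(1280/128) = 10):

    import math
    import torch

    def load_sharded_scale(param: torch.Tensor, loaded_scale: torch.Tensor,
                           dim: int, tp_rank: int, tp_size: int) -> None:
        D = loaded_scale.size(dim)
        shard = math.ceil(D / tp_size)        # was D // tp_size, which dropped the partial block
        start = tp_rank * shard
        size = min(shard, D - start)          # last rank may own fewer blocks than `shard`
        if size > 0:
            param.narrow(dim, 0, size).copy_(loaded_scale.narrow(dim, start, size))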

Verification:
  FP8 tp=4: 4 prompts, TTFT=92ms, TPOT=14ms, coherent output ✅
  BF16 tp=4 regression: TTFT=76-77ms, coherent output ✅
  FP8 tp=2 regression: TTFT=86ms, coherent output ✅
…ding

With NPerBlock=64 CK kernel support, inter_dim=320 (tp=4) is 64-aligned
and no longer requires zero-padding to 384. Changed align from
'64 if inter<=192 else block_n' to always 64, so:
- tp=4 (inter=320): 320%64=0 -> no padding (was 320->384, saved 17% compute)
- tp=8 (inter=160): 160%64=32 -> pad to 192 (unchanged)
- tp=2 (inter=640): 640%64=0 -> no padding (unchanged)

Scale tensor shape (ceil(320/128)=3) unchanged; no re-quantization needed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Stage2 KPerBlock=64 is not compilable on gfx950 (FP8 mfma KPack=32
constraint). Since stage1 output and stage2 weight K must match,
both w13 and w2 require the same inter_dim padding. Restoring:
  align = 64 if inter_dim <= 192 else block_n (=128)

Added comment explaining why full no-padding is currently blocked.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…onfigs

_process_block_quant used 'align = 64 if inter_dim <= 192 else block_n',
copied from the BF16 path. For FP8 blockscale this is wrong:
- FP8 stage2 only has KPerBlock=128 (KPack=32 mfma constraint prevents KPerBlock=64)
- align=64 gives inter_pad=192 for tp=8 (inter=160), but 192 % 128 = 64 != 0
- device_moe_gemm_blockscale.hpp L448 rejects K % KPerBlock != 0 → kernel fails

Fix: always use align = block_n (=128 for per_1x128), so inter_pad is always
a multiple of 128 and stage2 KPerBlock=128 dispatch succeeds:
  tp=2: inter=640 → 640 (no padding, unchanged)
  tp=4: inter=320 → 384 (unchanged)
  tp=8: inter=160 → 256 (was 192, now correctly aligned)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When the per_1x128 scale block count is smaller than tp_size (observed
on Step-3.5-Flash-FP8 at tp=8 with inter_dim=1280 → D=10), the ceil
split leaves trailing ranks with start >= D so narrow(start, size) hits
size<0 and crashes weight load. Skip narrow + copy_ for those ranks.

For fp8 scale tensors (torch.ones() initialised in
Fp8MoEMethod._create_weights), additionally zero the rank's slot before
the early return. Otherwise the downstream fp8 dequant multiplies the
(uninitialised) fp8 weight by stale 1.0 instead of the correct
quantization scale, contaminating the column gather / row reduction and
producing garbled output. Matches MXFP4 scale init (moe.py:776,813).
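Extending the load_sharded_scale sketch from the earlier commit, the added guard looks roughly like this (is_fp8_scale and the other names remain illustrative):

    # Inserted right after computing `start`, before the narrow/copy:
    if start >= D:
        # e.g. D=10 blocks at tp=8 -> shard=2, so ranks 5..7 have start >= 10.
        if is_fp8_scale:
            param.zero_()   # replace the torch.ones() init so dequant yields 0, not stale-1.0 garbage
        return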

Verified on stepfun-ai/Step-3.5-Flash-FP8 (gfx942 / MI308X):
- tp=8 A1/A2/A4 PASS — 4/4 prompts coherent (was: weight-load crash
  pre-patch; was: garbled output with early-return-only)
- tp=2/tp=4 A1/A2/A3 PASS — no regression, zero-trigger confirmed
  (D=10, starts=[0,3,6,9] for tp=4, starts=[0,5] for tp=2 — all < D)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>