diff --git a/atom/examples/simple_inference.py b/atom/examples/simple_inference.py index e4b37173c..d611fc9f2 100644 --- a/atom/examples/simple_inference.py +++ b/atom/examples/simple_inference.py @@ -19,6 +19,9 @@ parser.add_argument( "--temperature", type=float, default=0.6, help="temperature for sampling" ) +parser.add_argument( + "--max-tokens", type=int, default=256, help="max tokens to generate" +) def generate_cuda_graph_sizes(max_size): @@ -46,9 +49,11 @@ def main(): engine_args = EngineArgs.from_cli_args(args) llm = engine_args.create_engine() - tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained( + args.model, trust_remote_code=getattr(args, "trust_remote_code", False) + ) - sampling_params = SamplingParams(temperature=args.temperature, max_tokens=256) + sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens) prompts = [ tokenizer.apply_chat_template( @@ -60,9 +65,6 @@ def main(): for prompt in prompts ] print("This is prompts:", prompts) - # print("Warming up...") - # _ = llm.generate(["warmup"], sampling_params) - # print("Warm up done") print("\n" + "=" * 70) print("Starting profiling...") diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index bbd2d061c..1869d1b6f 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -66,6 +66,7 @@ "Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5.Qwen3_5MoeForConditionalGenerationTextOnly", "KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM", "MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM", + "Step3p5ForCausalLM": "atom.models.step3p5.Step3p5ForCausalLM", "MiMoV2FlashForCausalLM": "atom.models.mimo_v2_flash.MiMoV2FlashForCausalLM", } # seed = 34567 @@ -1027,11 +1028,15 @@ def allocate_forward_vars(self): def _get_num_kv_heads(self): """Return the per-rank number of KV heads.""" hf_config = self.config.hf_config - if hf_config.num_key_value_heads >= self.world_size: - assert hf_config.num_key_value_heads % self.world_size == 0 - return hf_config.num_key_value_heads // self.world_size + num_kv_heads_cfg = getattr( + hf_config, "num_key_value_heads", + getattr(hf_config, "num_attention_groups", None) + ) + if num_kv_heads_cfg >= self.world_size: + assert num_kv_heads_cfg % self.world_size == 0 + return num_kv_heads_cfg // self.world_size else: - assert self.world_size % hf_config.num_key_value_heads == 0 + assert self.world_size % num_kv_heads_cfg == 0 return 1 def _get_total_num_layers(self): @@ -1321,11 +1326,15 @@ def allocate_kv_cache(self, num_kvcache_blocks): self.num_physical_kvcache_blocks = ( num_kvcache_blocks * self.attn_metadata_builder.block_ratio ) - if hf_config.num_key_value_heads >= self.world_size: - assert hf_config.num_key_value_heads % self.world_size == 0 - num_kv_heads = hf_config.num_key_value_heads // self.world_size + num_kv_heads_cfg = getattr( + hf_config, "num_key_value_heads", + getattr(hf_config, "num_attention_groups", None) + ) + if num_kv_heads_cfg >= self.world_size: + assert num_kv_heads_cfg % self.world_size == 0 + num_kv_heads = num_kv_heads_cfg // self.world_size else: - assert self.world_size % hf_config.num_key_value_heads == 0 + assert self.world_size % num_kv_heads_cfg == 0 num_kv_heads = 1 # Calculate total number of layers (target + draft) diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 219b93f59..7ebbba8a4 100644 --- a/atom/model_loader/loader.py +++ 
b/atom/model_loader/loader.py @@ -337,10 +337,21 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: maybe_matching_name, f"mlp.experts.{hf_config.n_routed_experts}.", ) + # Check fused expert format before packed_modules_mapping to avoid + # expert weights (e.g. moe.gate_proj) being incorrectly matched + # by packed_modules_mapping entries (e.g. gate_proj -> gate_up_proj). + if detect_fused_expert_fn is not None and not is_fused_expert: + if detect_fused_expert_fn(name): + is_fused_expert = True + if get_fused_expert_mapping_fn is not None: + fused_expert_params_mapping = get_fused_expert_mapping_fn() for k in packed_modules_mapping: # We handle the experts below in expert_params_mapping if "mlp.experts." in name and name not in params_dict: continue + # Skip fused expert weights — handled below in expert loading path + if is_fused_expert and detect_fused_expert_fn is not None and detect_fused_expert_fn(name): + continue if k in name: packed_value = packed_modules_mapping[k] # Handle both tuple (fuse parameter) and list (shard parameter) diff --git a/atom/model_ops/attentions/aiter_attention.py b/atom/model_ops/attentions/aiter_attention.py index a8aa3a658..917987834 100644 --- a/atom/model_ops/attentions/aiter_attention.py +++ b/atom/model_ops/attentions/aiter_attention.py @@ -83,7 +83,7 @@ def __init__( else: max_qlen = 1 - num_head_k = max(1, hf_config.num_key_value_heads // get_tp_group().world_size) + num_head_k = max(1, getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_groups", None)) // get_tp_group().world_size) ( (work_meta_data_size, work_meta_data_type), (work_indptr_size, work_indptr_type), @@ -240,7 +240,7 @@ def set_aiter_persistent_worker_buffers(self, bs: int): hf_config = config.hf_config num_query_heads = self.num_attention_heads num_kv_heads = max( - 1, hf_config.num_key_value_heads // get_tp_group().world_size + 1, getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_groups", None)) // get_tp_group().world_size ) block_size = self.block_size diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 2315d371c..28d5d633d 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -485,7 +485,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight = atom_parameter(self._maybe_pad_weight(layer.w13_weight.data)) layer.w2_weight = atom_parameter(self._maybe_pad_weight(layer.w2_weight.data)) - # reshaping weights is required for aiter moe kernel. + + # gfx950 CK a16w16 stage2 requires inter_dim % 64 == 0. + # For tp=4 (inter=320) and tp=8 (inter=160), pad inter_dim up to the + # next multiple of 64. Zero padding is safe because fused_moe clips + # routed-weight contributions and zero-padded rows contribute nothing. + # Verified 2026-04-24: cos_sim >= 0.9999 for inter=160->192 and + # inter=320->384 vs torch reference. + w13 = layer.w13_weight.data # [E, 2*inter, hidden] + w2 = layer.w2_weight.data # [E, hidden, inter] + inter_dim = w2.shape[2] + # Stage1 dispatch: inter<=192 uses NPerBlock=64, inter>192 uses NPerBlock=128. + # Stage2 dispatch: inter>192 uses KPerBlock=64. + # So required alignment: 64 when inter<=192, 128 when inter>192. 
+ # (inter=160->192 satisfies 192%64=0; inter=320->384 satisfies 384%128=0 and 384%64=0) + align = 64 if inter_dim <= 192 else 128 + inter_pad = (inter_dim + align - 1) // align * align + if inter_pad != inter_dim: + E, _, hidden = w13.shape + # pad w13: gate half [E, inter, hidden] and up half [E, inter, hidden] + w13_new = torch.zeros( + E, 2 * inter_pad, hidden, dtype=w13.dtype, device=w13.device + ) + w13_new[:, :inter_dim, :] = w13[:, :inter_dim, :] # gate + w13_new[:, inter_pad : inter_pad + inter_dim, :] = w13[:, inter_dim:, :] # up + # pad w2: [E, hidden, inter_pad] + w2_new = torch.zeros( + E, hidden, inter_pad, dtype=w2.dtype, device=w2.device + ) + w2_new[:, :, :inter_dim] = w2 + layer.w13_weight = atom_parameter(w13_new) + layer.w2_weight = atom_parameter(w2_new) + + # Shuffle weights for CK/ASM kernels. + # Previously skipped for gfx950 bf16 g1u1 on the assumption that the CK + # 2-stage preshuffle_off (NSwizzle=0) kernel expected un-shuffled weights. + # Verified 2026-04-23: preshuffle_off GEMM is wrong on gfx950; preshuffle_on + # (NSwizzle=1) is correct. Always shuffle so the right kernel path is used. shuffle_weights(layer.w13_weight, layer.w2_weight) def get_fused_moe_quant_config( @@ -998,6 +1034,13 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map, ) + _tp_size = getattr(self, 'tp_size', getattr(self.moe, 'tp_size', 1)) if hasattr(self, 'moe') else getattr(self, 'tp_size', 1) + _ep_size = getattr(self, 'ep_size', getattr(self.moe, 'ep_size', 1)) if hasattr(self, 'moe') else getattr(self, 'ep_size', 1) + if layer.reduce_results and (_tp_size > 1 or _ep_size > 1): + from aiter.dist.parallel_state import get_tp_group + _moe_result = get_tp_group().all_reduce( + _moe_result, ca_fp8_quant=False + ) return _moe_result assert ( @@ -1516,25 +1559,38 @@ def create_weights( block_n = 1 block_k = 32 tp_size = get_tp_group().world_size + # Pad intermediate_size_per_partition to the nearest multiple of + # block_n so that both the CK blockscale kernel alignment constraints + # and the block-quantization scale alignment constraints are satisfied. + # For tp=4, inter_dim=320 is not divisible by block_n=128; padding to + # 384 (=3×128) satisfies both stage1 NPerBlock=128 and block_n=128. + # The weight loader already supports partial loading into a padded + # buffer via the MXFP4 alignment path (_load_w13 / _load_w2). + padded_inter = ( + (intermediate_size_per_partition + block_n - 1) // block_n * block_n + ) # NOTE: To ensure proper alignment of the block-wise quantization # scales, the output_size of the weights for both the gate and up # layers must be divisible by block_n. # Required by column parallel or enabling merged weights - if intermediate_size_per_partition % block_n != 0: + if padded_inter % block_n != 0: raise ValueError( f"The output_size of gate's and up's weight = " - f"{intermediate_size_per_partition} is not divisible by " + f"{padded_inter} is not divisible by " f"weight quantization block_n = {block_n}." ) - if tp_size > 1 and intermediate_size_per_partition % block_k != 0: + if tp_size > 1 and padded_inter % block_k != 0: # Required by row parallel raise ValueError( f"The input_size of down's weight = " - f"{intermediate_size_per_partition} is not divisible by " + f"{padded_inter} is not divisible by " f"weight quantization block_k = {block_k}." 
        )

        # WEIGHTS
+        # Allocated at original (un-padded) size; inter_dim padding is applied
+        # later in _process_block_quant (after normalize, before shuffle_weights),
+        # mirroring the BF16 approach in UnquantizedFusedMoEMethod.
        w13_weight = atom_parameter(
            torch.empty(
                num_experts,
@@ -1654,6 +1710,41 @@ def _process_block_quant(self, layer: nn.Module) -> None:
        assert self.quant_config.is_dynamic
        self._normalize_weights_and_scales(layer)

+        # Inter-dim padding for block-quantized FP8 (mirrors the BF16 approach
+        # in UnquantizedFusedMoEMethod.process_weights_after_loading).
+        # When inter_dim is not a multiple of block_n (e.g. tp=4: 320 % 128 ≠ 0),
+        # zero-pad both weights to the nearest block_n multiple BEFORE shuffling.
+        # Padding area is zero so dequant(0, scale) = 0 is numerically safe.
+        # Scale tensors use ceil(inter/block_n) and are already shape-compatible.
+        inter_dim = layer.w2_weight.shape[-1]
+        block_n = 128 if self.quant_type == QuantType.per_1x128 else 32
+        # FP8 blockscale stage2 requires KPerBlock=128 (the gfx950 FP8 mfma
+        # KPack=32 constraint rules out KPerBlock=64), so align must always be
+        # block_n (=128) to guarantee inter_pad % 128 == 0.
+        # Bug fix: we previously used align=64 for inter<=192 (copied from the
+        # BF16 path), but 192 % 128 = 64 ≠ 0, so stage2 kernel dispatch fails.
+        # Correct: always align to block_n.
+        # tp=8: inter=160 → 256 (= ceil(160/128) * 128);
+        # tp=4: inter=320 → 384 (= ceil(320/128) * 128).
+        align = block_n
+        inter_pad = (inter_dim + align - 1) // align * align
+        if inter_pad != inter_dim:
+            E = layer.w13_weight.shape[0]
+            hidden = layer.w13_weight.shape[-1]
+            w13 = layer.w13_weight.data
+            w13_new = torch.zeros(
+                E, 2 * inter_pad, hidden, dtype=w13.dtype, device=w13.device
+            )
+            w13_new[:, :inter_dim, :] = w13[:, :inter_dim, :]  # gate
+            w13_new[:, inter_pad:inter_pad + inter_dim, :] = w13[:, inter_dim:, :]  # up
+            layer.w13_weight = atom_parameter(w13_new)
+
+            w2 = layer.w2_weight.data
+            w2_new = torch.zeros(
+                E, hidden, inter_pad, dtype=w2.dtype, device=w2.device
+            )
+            w2_new[:, :, :inter_dim] = w2
+            layer.w2_weight = atom_parameter(w2_new)
+
        if not self.need_normalize_e4m3fn_to_e4m3fnuz:
            layer.w13_weight = atom_parameter(layer.w13_weight.data)
            layer.w13_weight_scale = atom_parameter(layer.w13_weight_scale.data)
@@ -1726,14 +1817,19 @@ def get_fused_moe_quant_config(
                a2_scale=layer.w2_input_scale,
                per_act_token_quant=True,
            )
+        elif self.quant_type == QuantType.per_1x128:
+            block_shape = [128, 128]
+        elif self.quant_type == QuantType.per_1x32:
+            block_shape = [1, 32]
        else:
-            return fp8_w8a8_moe_quant_config(
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-                block_shape=None,
-            )
+            block_shape = None
+        return fp8_w8a8_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=block_shape,
+        )

    @mark_trace(prefix="fp8_moe", torch_compile=False)
    def apply(
@@ -2207,10 +2303,36 @@ def _load_w13(
        expert_shard_size = expert_data.shape[shard_dim] // 2
        # Derive shard size from loaded_weight (unpadded checkpoint) to avoid
        # out-of-bounds when expert_data is padded (e.g. MXFP4 alignment).
-        load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
-        loaded_weight = loaded_weight.narrow(
-            shard_dim, load_shard_size * tp_rank, load_shard_size
-        )
+        # Use ceil so that the last partial scale block (e.g. per_1x128 with
+        # inter=1280 and tp=4: 10 blocks / 4 = 2.5 → ceil=3) is included.
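+        # Worked example of the ceil split (numbers from the case above):
+        #   D = 10 scale blocks, tp = 4 → load_shard_size = ceil(10/4) = 3;
+        #   rank starts are 0, 3, 6, 9, so ranks 0-2 copy 3 blocks each and
+        #   rank 3 copies min(3, 10 - 9) = 1 block.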
+ # Without ceil, the 3rd scale block is never copied and stays at the + # torch.ones() initial value of 1.0, causing ~5000× dequant error. + load_shard_size = ( + loaded_weight.shape[shard_dim] + self.tp_size - 1 + ) // self.tp_size + start = load_shard_size * tp_rank + # When D < tp_size (e.g. per_1x128 scale block count smaller than + # tp_size, observed at tp=8 with inter=1280 → D=10), the ceil split + # gives some trailing ranks start >= D so they hold no slice of the + # loaded tensor. Skip narrow + copy_ for those ranks; the rank's + # slice of expert_data stays at its initialised value (0 for weight, + # 1.0 for scale) and the rank contributes a no-op to the column + # gather / row reduction. + if start >= loaded_weight.shape[shard_dim]: + # FP8 scale tensors are torch.ones() initialised. If we leave the + # trailing rank's slice at 1.0, the downstream FP8 dequant multiplies + # the (uninitialised) fp8 weight by 1.0 instead of the correct + # quantization scale, contaminating the column gather / row reduction. + # Zero the slot so dequant produces 0 and the rank contributes a + # true no-op (matches MXFP4 scale init at moe.py:776,813). + if expert_data.dtype == torch.float32: + if shard_id == "w1": + expert_data.narrow(shard_dim, 0, expert_shard_size).zero_() + else: + expert_data.narrow(shard_dim, expert_shard_size, expert_shard_size).zero_() + return + size = min(load_shard_size, loaded_weight.shape[shard_dim] - start) + loaded_weight = loaded_weight.narrow(shard_dim, start, size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -2246,10 +2368,24 @@ def _load_w2( if not load_full: # Derive shard size from loaded_weight (unpadded checkpoint) to # avoid out-of-bounds when expert_data is padded (e.g. MXFP4). - load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size - loaded_weight = loaded_weight.narrow( - shard_dim, load_shard_size * tp_rank, load_shard_size - ) + # Use ceil (same reason as _load_w13: partial last scale block). + load_shard_size = ( + loaded_weight.shape[shard_dim] + self.tp_size - 1 + ) // self.tp_size + start = load_shard_size * tp_rank + # See _load_w13 comment above: when D < tp_size the ceil split + # leaves trailing ranks with no slice; skip narrow + copy_. + if start >= loaded_weight.shape[shard_dim]: + # Zero the scale slice so dequant=0 instead of multiplying by + # stale init=1.0; see _load_w13 comment for full rationale. + if expert_data.dtype == torch.float32: + if load_shard_size != shard_size: + expert_data.narrow(shard_dim, 0, load_shard_size).zero_() + else: + expert_data.zero_() + return + size = min(load_shard_size, loaded_weight.shape[shard_dim] - start) + loaded_weight = loaded_weight.narrow(shard_dim, start, size) if load_shard_size != shard_size: expert_data = expert_data.narrow(shard_dim, 0, load_shard_size) # w2, down_proj: Load into only logical weight of w2. @@ -2556,6 +2692,14 @@ def select_experts( routed_scaling_factor: float = 1.0, ): + # Model-provided custom routing (e.g. 
sigmoid + bias + scaling for Step-3.5) + if custom_routing_function is not None: + return custom_routing_function( + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + # DeekSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None @@ -2697,6 +2841,7 @@ def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) hidden_states = naive_multicast(hidden_states, cu_tokens_across_dp_cpu) router_logits = naive_multicast(router_logits, cu_tokens_across_dp_cpu) + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, diff --git a/atom/models/step3p5.py b/atom/models/step3p5.py new file mode 100644 index 000000000..c5e97c995 --- /dev/null +++ b/atom/models/step3p5.py @@ -0,0 +1,918 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Inference-only Step-3.5 (Flash) model. + +Step-3.5 is a sparse MoE transformer with: + - 45 decoder layers, hidden_size=4096, head_dim=128 + - GQA with two attention configs: full_attention (64 q heads, 8 kv groups) + and sliding_attention (96 q heads, 8 kv groups, window=512) + - 3:1 sliding window pattern (1 full + 3 sliding) + - Per-layer rope_theta and partial_rotary_factor + - QK RMSNorm (zero-centered, i.e. weight * (1 + param)) + - Head-wise attention gating via g_proj (sigmoid) + - MoE on layers 3-44: 288 routed experts + 1 shared expert, top-8, + sigmoid routing with learnable router bias + - Dense MLP on layers 0-2 + - Per-layer SwiGLU clamp limits + - Multi-token prediction (MTP) with num_nextn_predict_layers=3 +""" + +import os +from typing import Any, Optional, Union + +import torch +from aiter import ActivationType +from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size +from aiter.rotary_embedding import get_rope +from atom.config import Config, QuantizationConfig +from atom.model_ops.activation import SiluAndMul +from atom.model_ops.base_attention import Attention +from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding +from atom.model_ops.layernorm import GemmaRMSNorm as Step3p5RMSNorm +from atom.model_ops.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from atom.model_ops.moe import FusedMoE +from atom.model_ops.topK import is_rocm_aiter_fusion_shared_expert_enabled +from atom.models.utils import ( + IntermediateTensors, + PPMissingLayer, + extract_layer_index, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) +from atom.utils.decorators import support_torch_compile +from torch import nn +from transformers import PretrainedConfig + + +def _uses_swiglustep_at_layer(config: PretrainedConfig, layer_idx: Optional[int]) -> bool: + """Return True iff the routed FusedMoE at this layer needs the SwigluStep + activation (i.e. ``swiglu_limits[layer_idx] > 0``). + + The CK kernel hard-codes the clamp at 7.0; Step-3.5-Flash uses 7.0 at + layers 43 and 44, which is why the kernel is only valid at those layers. + Other layers must keep the plain Silu path. + """ + if layer_idx is None: + return False + # Toggle-off bit: ATOM_DISABLE_SWIGLUSTEP=1 forces plain Silu at every + # layer (verification helper only). 
+ if os.environ.get("ATOM_DISABLE_SWIGLUSTEP"): + return False + swiglu_limits = getattr(config, "swiglu_limits", None) + if not swiglu_limits or layer_idx >= len(swiglu_limits): + return False + return bool(swiglu_limits[layer_idx]) + + +def _fuse_shared_at_layer(config: PretrainedConfig, layer_idx: Optional[int]) -> bool: + """Whether to fuse the shared expert into the routed FusedMoE at this layer. + + R5 mitigation: at SwigluStep layers the kernel clamps every expert at 7.0, + but the shared expert may use a different clamp (e.g. 16 at layer 44 or 0 + at layer 43). Therefore the shared expert MUST stay on the dense path at + every SwigluStep layer, even when the global aiter fusion is enabled. + """ + # ATOM_FORCE_FUSE_SHARED=1 always fuses the shared expert into the + # routed kernel (verification helper: bypass R5 mitigation). + if os.environ.get("ATOM_FORCE_FUSE_SHARED"): + return is_rocm_aiter_fusion_shared_expert_enabled() + return ( + is_rocm_aiter_fusion_shared_expert_enabled() + and not _uses_swiglustep_at_layer(config, layer_idx) + ) + + +# --------------------------------------------------------------------------- +# MLP (dense, used for first few layers and shared expert) +# --------------------------------------------------------------------------- + + +class Step3p5MLP(nn.Module): + """Dense SwiGLU MLP with optional activation clamping.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + reduce_results: bool = True, + clamp_limit: Optional[float] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + # 0.0 means no clamping (disabled), only apply if > 0 + self.clamp_limit = clamp_limit if (clamp_limit is not None and clamp_limit > 0) else None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.gate_up_proj(x) + if self.clamp_limit is not None: + # Match HF: clamp AFTER silu activation on gate, symmetric on up + # gate_proj output is first half, up_proj output is second half + half = x.shape[-1] // 2 + gate, up = x[..., :half], x[..., half:] + gate = torch.nn.functional.silu(gate).clamp(max=self.clamp_limit) + up = up.clamp(min=-self.clamp_limit, max=self.clamp_limit) + x = self.down_proj(gate * up) + else: + x = self.act_fn(x) + x = self.down_proj(x) + return x + + +# --------------------------------------------------------------------------- +# MoE block (routed experts + shared expert) +# --------------------------------------------------------------------------- + + +class Step3p5MoE(nn.Module): + """Sparse MoE block for Step-3.5. + + Checkpoint weight layout under ``layers.{i}.moe.*``: + - gate.weight (router linear) + - router_bias (learnable additive bias for sigmoid routing) + - gate_proj.weight (per-expert, shape [num_experts, intermediate, hidden]) + - up_proj.weight (per-expert) + - down_proj.weight (per-expert) + + The FusedMoE kernel maps these via ``get_expert_mapping``. 
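+
+    Shape sketch of the flat loading path (E = moe_num_experts,
+    I = moe_intermediate_size, H = hidden_size; the actual copy is done by
+    Step3p5ForCausalLM.load_fused_expert_weights using
+    get_fused_expert_mapping; down_proj shape assumed [E, H, I]):
+
+        moe.gate_proj.weight  [E, I, H]  ->  moe.experts.w13_weight (w1 half)
+        moe.up_proj.weight    [E, I, H]  ->  moe.experts.w13_weight (w3 half)
+        moe.down_proj.weight  [E, H, I]  ->  moe.experts.w2_weight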
+ """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + num_experts: int = config.moe_num_experts + top_k: int = config.moe_top_k + moe_intermediate_size: int = config.moe_intermediate_size + + # Per-layer SwiGLU clamp limit for routed experts. + # Step-3.5 applies clamp(x, -limit, limit) after gate_up_proj and + # before the SwiGLU activation inside each expert. The CK kernel + # implements this as ``ActivationType.SwigluStep`` with a hard-coded + # ±7 clamp; Step-3.5-Flash uses 7 at layers 43-44 only. + layer_idx = extract_layer_index(prefix) if prefix else None + self._layer_idx = layer_idx + swiglu_limits = getattr(config, "swiglu_limits", None) + self.clamp_limit = ( + swiglu_limits[layer_idx] + if (swiglu_limits and layer_idx is not None and swiglu_limits[layer_idx] > 0) + else None + ) + self._uses_swiglustep = self.clamp_limit is not None + self._activation = ( + ActivationType.SwigluStep if self._uses_swiglustep else ActivationType.Silu + ) + + # Router --------------------------------------------------------- + self.gate = ReplicatedLinear( + self.hidden_size, + num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + # Learnable router bias (added to sigmoid probs before top-k) + self.router_bias = nn.Parameter( + torch.zeros(num_experts, dtype=torch.float32), + requires_grad=False, + ) + + self.routed_scaling_factor = getattr( + config, "moe_router_scaling_factor", 1.0 + ) + self._need_fp32_gate = getattr(config, "need_fp32_gate", False) + + # Routed experts (fused MoE kernel) -------------------------------- + # R5 mitigation: at SwigluStep layers we MUST NOT fuse the shared + # expert into the routed FusedMoE (the kernel hard-codes ±7 clamp, + # but the shared expert uses a different limit, e.g. 16 at layer 44 + # or 0 at layer 43). Fall back to the dense Step3p5MLP path there. + self._fuse_shared = _fuse_shared_at_layer(config, layer_idx) + n_shared = 1 if self._fuse_shared else 0 + self._n_shared_fused = n_shared # 1 when shared expert is fused as expert num_experts + self.experts = FusedMoE( + num_experts=num_experts + n_shared, + top_k=top_k + n_shared, # +1 so kernel selects top_k routed + 1 shared + hidden_size=self.hidden_size, + intermediate_size=moe_intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + custom_routing_function=self._routing_function, + config=config, + activation=self._activation, + ) + + def _routing_function( + self, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ): + """Sigmoid routing with additive bias and scaling. + + When the shared expert is fused (self._n_shared_fused == 1), topk is + top_k_routed + 1. We select top_k_routed routed experts and append + the shared expert (index num_routed_experts) with weight 1.0. 
+ """ + n_shared = self._n_shared_fused + top_k_routed = topk - n_shared # number of routed experts to pick + + gate_prob = torch.sigmoid(gating_output.float()) + gate_prob_biased = gate_prob + self.router_bias.unsqueeze(0) + _, indices = torch.topk(gate_prob_biased, k=top_k_routed, dim=1) + topk_prob = torch.gather(gate_prob, 1, indices) + if renormalize: + topk_prob = topk_prob / ( + topk_prob.sum(dim=-1, keepdim=True) + 1e-20 + ) + topk_prob = topk_prob * self.routed_scaling_factor + + if n_shared > 0: + # Append shared expert (always selected, weight=1.0) + T = gating_output.shape[0] + num_routed = gating_output.shape[1] # 288 + shared_ids = torch.full( + (T, n_shared), num_routed, dtype=torch.int32, device=gating_output.device + ) + shared_weights = torch.ones( + (T, n_shared), dtype=torch.float32, device=gating_output.device + ) + topk_prob = torch.cat([topk_prob, shared_weights], dim=1) + indices = torch.cat([indices.to(torch.int32), shared_ids], dim=1) + return topk_prob, indices + + return topk_prob, indices.to(torch.int32) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + + # Router logits must be computed in fp32 (need_fp32_gate=True in config) + if getattr(self, "_need_fp32_gate", True): + router_logits = torch.nn.functional.linear( + hidden_states.float(), self.gate.weight.float() + ) + else: + router_logits = self.gate(hidden_states) + + # Routed experts. At SwigluStep layers (43-44) the FusedMoE was + # constructed with ``activation=ActivationType.SwigluStep`` so the CK + # kernel applies ``silu(g).clamp(max=7) * up.clamp(±7)`` per expert. + routed_out = self.experts(hidden_states, router_logits) + + return routed_out.view(orig_shape) + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class Step3p5Attention(nn.Module): + """GQA attention for Step-3.5. + + Key differences from vanilla LLaMA attention: + - Per-layer rope_theta and partial_rotary_factor (from config lists). + - Two attention head configurations depending on full vs sliding. + - QK RMSNorm (zero-centered / GemmaRMSNorm style). + - Head-wise attention gating via g_proj (sigmoid). 
+ """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: str = "bf16", + prefix: str = "", + layer_num: int = 0, + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = config.hidden_size + + # Determine layer type and head counts ---------------------------- + layer_types = getattr(config, "layer_types", []) + is_sliding = ( + layer_types[layer_idx] == "sliding_attention" if layer_types else False + ) + attn_other = getattr(config, "attention_other_setting", None) + + if is_sliding and attn_other is not None: + self.total_num_heads = attn_other["num_attention_heads"] + self.total_num_kv_heads = attn_other["num_attention_groups"] + else: + self.total_num_heads = config.num_attention_heads + self.total_num_kv_heads = config.num_attention_groups + + self.head_dim = getattr(config, "head_dim", 128) + + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + + # RoPE configuration ----------------------------------------------- + rope_theta_cfg = getattr(config, "rope_theta", 10000.0) + if isinstance(rope_theta_cfg, list): + rope_theta = rope_theta_cfg[layer_idx] + else: + rope_theta = rope_theta_cfg + + partial_rotary_factors = getattr(config, "partial_rotary_factors", None) + if partial_rotary_factors is not None: + partial_rotary_factor = partial_rotary_factors[layer_idx] + else: + partial_rotary_factor = 1.0 + + rotary_dim = int(self.head_dim * partial_rotary_factor) + + max_position_embeddings = getattr( + config, "max_position_embeddings", 262144 + ) + + # Determine rope_scaling for this layer + rope_scaling = getattr(config, "rope_scaling", None) + yarn_only_types = getattr(config, "yarn_only_types", None) + if yarn_only_types and layer_types: + layer_type = layer_types[layer_idx] + if layer_type not in yarn_only_types: + rope_scaling = None + + # Projections ------------------------------------------------------- + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # QK Norm (zero-centered RMSNorm) ----------------------------------- + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + self.q_norm = Step3p5RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = Step3p5RMSNorm(self.head_dim, eps=rms_norm_eps) + + # Head-wise attention gate ------------------------------------------- + self.use_head_wise_attn_gate = getattr( + config, "use_head_wise_attn_gate", False + ) + if self.use_head_wise_attn_gate: + self.g_proj = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=self.total_num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_proj", + ) + + # Rotary embedding --------------------------------------------------- + # Note: rotary_dim is already computed as head_dim * partial_rotary_factor, + # so we do NOT pass partial_rotary_factor to 
get_rope (which would apply it twice).
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=True,
+        )
+
+        # Sliding window and sink tokens per layer ---------------------------
+        # Enabled by default; ATOM_STEP3P5_NO_SLIDING=1 disables the sliding
+        # window (debug helper, e.g. for isolating pa_decode_gluon issues).
+        sliding_window = None
+        sinks = None
+        if is_sliding and not os.environ.get("ATOM_STEP3P5_NO_SLIDING"):
+            sliding_window = getattr(config, "sliding_window", None)
+            sink_size = getattr(config, "sink", 0)
+            if sink_size > 0:
+                sinks = nn.Parameter(
+                    torch.empty(self.num_heads), requires_grad=False
+                )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            kv_cache_dtype=cache_config,
+            layer_num=layer_num,
+            per_layer_sliding_window=sliding_window,
+            sinks=sinks,
+            prefix=f"{prefix}.attn",
+            rotary_emb=self.rotary_emb,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        debug_nan = os.environ.get("ATOM_DEBUG_NAN2")
+
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = torch.split(
+            qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1
+        )
+        if debug_nan and (q.isnan().any() or k.isnan().any() or v.isnan().any()):
+            print(f"[NAN-ATTN] qkv has NaN: q={q.isnan().any()} k={k.isnan().any()} v={v.isnan().any()}")
+
+        # QK Norm – apply per-head RMSNorm.
+        # Reshape to (..., num_heads, head_dim), apply norm, reshape back.
+        q = self.q_norm(q.reshape(*q.shape[:-1], -1, self.head_dim)).flatten(-2)
+        k = self.k_norm(k.reshape(*k.shape[:-1], -1, self.head_dim)).flatten(-2)
+        if debug_nan and (q.isnan().any() or k.isnan().any()):
+            print(f"[NAN-ATTN] after qk_norm NaN: q={q.isnan().any()} k={k.isnan().any()}")
+
+        attn_output = self.attn(q, k, v, positions)
+        if debug_nan and attn_output.isnan().any():
+            print(f"[NAN-ATTN] attn output has NaN, num_heads={self.num_heads}, num_kv_heads={self.num_kv_heads}, q_size={self.q_size}, kv_size={self.kv_size}, attn_output.shape={attn_output.shape}")
+
+        # Head-wise gating: one sigmoid gate per head, broadcast over head_dim.
+        if self.use_head_wise_attn_gate:
+            gate = self.g_proj(hidden_states)  # (tokens, num_heads_tp)
+            if debug_nan and gate.isnan().any():
+                print(f"[NAN-ATTN] gate (g_proj) has NaN, gate.shape={gate.shape}")
+            # gate: (tokens, num_heads_tp) -> (tokens, num_heads_tp, 1)
+            gate = torch.sigmoid(gate).unsqueeze(-1)
+            reshaped = attn_output.reshape(*attn_output.shape[:-1], -1, self.head_dim)
+            if debug_nan:
+                print(f"[NAN-ATTN] attn_output.shape={attn_output.shape} reshaped.shape={reshaped.shape} gate.shape={gate.shape}")
+            attn_output = (reshaped * gate).flatten(-2)
+            if debug_nan and attn_output.isnan().any():
+                print("[NAN-ATTN] after gate multiply has NaN")
+
+        output = self.o_proj(attn_output)
+        if debug_nan and output.isnan().any():
+            print("[NAN-ATTN] o_proj output has NaN")
+        return output
+
+
+# ---------------------------------------------------------------------------
+# Decoder Layer
+# ---------------------------------------------------------------------------
+
+
+class Step3p5DecoderLayer(nn.Module):
+    """Single decoder layer for Step-3.5.
+ + - Layers 0-2: dense MLP + - Layers 3-44: MoE (288 routed + 1 shared) + """ + + def __init__( + self, + config: PretrainedConfig, + cache_config: str = "bf16", + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + layer_num: int = 0, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) + + # Attention + self.self_attn = Step3p5Attention( + config=config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + layer_num=layer_num, + ) + + # FFN: dense MLP or MoE depending on layer index + moe_layers_enum = getattr(config, "moe_layers_enum", None) + if moe_layers_enum is not None: + if isinstance(moe_layers_enum, str): + moe_layers_idx = [ + int(i) for i in moe_layers_enum.strip().split(",") + ] + else: + moe_layers_idx = list(moe_layers_enum) + else: + moe_layers_idx = list(range(3, config.num_hidden_layers)) + + self.is_moe_layer = layer_idx in moe_layers_idx + + # Per-layer SwiGLU clamp limits + swiglu_limits = getattr(config, "swiglu_limits", None) + swiglu_limits_shared = getattr(config, "swiglu_limits_shared", None) + clamp_limit = swiglu_limits[layer_idx] if swiglu_limits else None + clamp_limit_shared = swiglu_limits_shared[layer_idx] if swiglu_limits_shared else None + + if self.is_moe_layer: + self.moe = Step3p5MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.moe", + ) + # Shared expert (always active, sibling of moe in checkpoint). + # Per-layer fuse decision: SwigluStep layers must keep the shared + # expert on the dense path because the routed CK kernel hard-codes + # the clamp at 7 (see _fuse_shared_at_layer). + if not _fuse_shared_at_layer(config, layer_idx): + self.share_expert = Step3p5MLP( + hidden_size=self.hidden_size, + intermediate_size=config.share_expert_dim, + quant_config=quant_config, + prefix=f"{prefix}.share_expert", + clamp_limit=clamp_limit_shared, + ) + else: + self.share_expert = None + else: + self.mlp = Step3p5MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + clamp_limit=clamp_limit_shared, # HF uses swiglu_limits_shared for dense MLP + ) + + # Layer norms (zero-centered RMSNorm) + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + self.input_layernorm = Step3p5RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + self.post_attention_layernorm = Step3p5RMSNorm( + config.hidden_size, eps=rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual + ) + + hidden_states = self.self_attn( + positions=positions, hidden_states=hidden_states + ) + + # FFN + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + if self.is_moe_layer: + moe_output = self.moe(hidden_states) + if self.share_expert is not None: + shared_output = self.share_expert(hidden_states) + hidden_states = moe_output + shared_output + else: + hidden_states = moe_output + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# --------------------------------------------------------------------------- +# Full Model +# 
--------------------------------------------------------------------------- + + +@support_torch_compile +class Step3p5Model(nn.Module): + def __init__( + self, + atom_config: Config, + prefix: str = "", + ): + super().__init__() + config = atom_config.hf_config + self.config = config + cache_config = atom_config.kv_cache_dtype + quant_config = atom_config.quant_config + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or ( + getattr(config, "tie_word_embeddings", False) + and get_pp_group().is_last_rank + ): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix, layer_num=None: Step3p5DecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + layer_num=layer_num, + ), + prefix=f"{prefix}.layers", + layer_num_offset=0, + ) + + rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5) + if get_pp_group().is_last_rank: + self.norm = Step3p5RMSNorm(config.hidden_size, eps=rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + import os as _dbg_os + _dbg_layer = _dbg_os.environ.get("ATOM_DEBUG_LAYER") + for i, layer in enumerate(self.layers[self.start_layer : self.end_layer]): + hidden_states, residual = layer(positions, hidden_states, residual) + if _dbg_layer: + _hs_nan = torch.isnan(hidden_states).any().item() + _res_nan = torch.isnan(residual).any().item() if residual is not None else False + if _hs_nan or _res_nan: + print(f"[LAYER NaN] layer={i} hs_nan={_hs_nan} res_nan={_res_nan} hs_norm={hidden_states.float().norm():.3f}", flush=True) + break + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +# --------------------------------------------------------------------------- +# CausalLM wrapper +# --------------------------------------------------------------------------- + + +class Step3p5ForCausalLM(nn.Module): + """Step-3.5 model with language modelling head.""" + + packed_modules_mapping = { + "q_proj": ("qkv_proj", "q"), + "k_proj": ("qkv_proj", "k"), + "v_proj": ("qkv_proj", "v"), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + atom_config: Config, + prefix: str = "", + ): + super().__init__() + config = atom_config.hf_config + self.config = config + + self.model = Step3p5Model( + atom_config=atom_config, + prefix=maybe_prefix(prefix, "model"), + ) + + if get_pp_group().is_last_rank: + 
self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if getattr(config, "tie_word_embeddings", False): + self.lm_head.weight = self.model.embed_tokens.weight + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + logits = self.lm_head(hidden_states) + return logits + + def detect_fused_expert_format(self, weight_name: str) -> bool: + """Step-3.5-Flash expert weights are flat: moe.gate_proj.weight [E, I, H]. + When shared expert fusion is enabled, share_expert weights are also loaded + as expert N in FusedMoE; otherwise they are loaded as a regular MLP. + + Per-layer override: at SwigluStep layers (43-44) the shared expert is + not fused, so its weights must take the dense path even when the + global aiter fusion flag is on. + """ + is_routed_expert = ( + ".moe.gate_proj" in weight_name + or ".moe.up_proj" in weight_name + or ".moe.down_proj" in weight_name + ) + if is_routed_expert: + return True + is_share_expert = ( + ".share_expert.gate_proj" in weight_name + or ".share_expert.up_proj" in weight_name + or ".share_expert.down_proj" in weight_name + ) + if is_share_expert: + layer_idx = extract_layer_index(weight_name) + return _fuse_shared_at_layer(self.config, layer_idx) + return False + + def get_fused_expert_mapping(self) -> list[tuple[str, str, str]]: + """Mapping from flat checkpoint names to FusedMoE parameter names. + + Weight names include the '.weight' suffix from the checkpoint so that + the replace() in loader.py produces the correct param name without the + extra '.weight' tail (e.g. 'moe.gate_proj.weight' -> 'moe.experts.w13_weight'). + """ + mapping = [ + ("moe.experts.w13_weight", "moe.gate_proj.weight", "w1"), + ("moe.experts.w13_weight", "moe.up_proj.weight", "w3"), + ("moe.experts.w2_weight", "moe.down_proj.weight", "w2"), + ] + if is_rocm_aiter_fusion_shared_expert_enabled(): + mapping += [ + ("moe.experts.w13_weight", "share_expert.gate_proj.weight", "w1"), + ("moe.experts.w13_weight", "share_expert.up_proj.weight", "w3"), + ("moe.experts.w2_weight", "share_expert.down_proj.weight", "w2"), + ] + return mapping + + def load_fused_expert_weights( + self, + original_name: str, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ) -> bool: + """Load flat expert weights [E, I, H] into FusedMoE per-expert params. + + For routed experts: loaded_weight is [num_experts, ...], loaded per-expert. + For shared expert: loaded_weight is [I, H] or [H, I], loaded as expert num_experts. 
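+
+        Illustrative flow (hypothetical tensors): a routed moe.gate_proj
+        weight arrives as [288, I, H] and is copied one expert at a time,
+        weight_loader(param, loaded_weight[e], name, "w1", e) for e in
+        range(288); a fused share_expert.gate_proj weight arrives 2-D and
+        is written once with expert_id == 288.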
+ """ + # num_experts from loader may be 0 if hf_config uses non-standard attr name + if num_experts == 0: + num_experts = self.config.moe_num_experts + + if name not in params_dict: + return False + param = params_dict[name] + weight_loader = param.weight_loader + loaded_local_expert = False + + is_share_expert = "share_expert" in original_name + + if is_share_expert: + # Defensive: if this layer keeps the shared expert dense (e.g. + # SwigluStep layers 43-44), do not route it through FusedMoE. + layer_idx = extract_layer_index(original_name) + if not _fuse_shared_at_layer(self.config, layer_idx): + return False + # Shared expert is loaded as expert index num_experts (288) + expert_id = num_experts + try: + success = weight_loader( + param, + loaded_weight, + name, + shard_id, + expert_id, + return_success=True, + ) + if success: + loaded_local_expert = True + except TypeError: + weight_loader(param, loaded_weight, name, shard_id, expert_id) + loaded_local_expert = True + else: + for expert_id in range(num_experts): + try: + success = weight_loader( + param, + loaded_weight[expert_id], + name, + shard_id, + expert_id, + return_success=True, + ) + if success: + loaded_local_expert = True + except TypeError: + weight_loader(param, loaded_weight[expert_id], name, shard_id, expert_id) + loaded_local_expert = True + + return loaded_local_expert + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + """Return the expert parameter mapping for weight loading. + + Note: Step-3.5-Flash uses flat expert weights in the checkpoint + (moe.gate_proj.weight etc.), so get_expert_mapping is used only + as a sentinel to enable the expert loading path in loader.py. + The actual loading is handled by load_fused_expert_weights. + """ + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.moe_num_experts + + ( + 1 if is_rocm_aiter_fusion_shared_expert_enabled() else 0 + ), + )