12 changes: 7 additions & 5 deletions atom/examples/simple_inference.py
@@ -19,6 +19,9 @@
parser.add_argument(
"--temperature", type=float, default=0.6, help="temperature for sampling"
)
parser.add_argument(
"--max-tokens", type=int, default=256, help="max tokens to generate"
)


def generate_cuda_graph_sizes(max_size):
@@ -46,9 +49,11 @@ def main():
engine_args = EngineArgs.from_cli_args(args)
llm = engine_args.create_engine()

tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(
args.model, trust_remote_code=getattr(args, "trust_remote_code", False)
)

sampling_params = SamplingParams(temperature=args.temperature, max_tokens=256)
sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens)

prompts = [
tokenizer.apply_chat_template(
@@ -60,9 +65,6 @@
for prompt in prompts
]
print("This is prompts:", prompts)
# print("Warming up...")
# _ = llm.generate(["warmup"], sampling_params)
# print("Warm up done")

print("\n" + "=" * 70)
print("Starting profiling...")
25 changes: 17 additions & 8 deletions atom/model_engine/model_runner.py
@@ -66,6 +66,7 @@
"Qwen3_5MoeForConditionalGeneration": "atom.models.qwen3_5.Qwen3_5MoeForConditionalGenerationTextOnly",
"KimiK25ForConditionalGeneration": "atom.models.kimi_k25.KimiK25ForCausalLM",
"MiniMaxM2ForCausalLM": "atom.models.minimax_m2.MiniMaxM2ForCausalLM",
"Step3p5ForCausalLM": "atom.models.step3p5.Step3p5ForCausalLM",
"MiMoV2FlashForCausalLM": "atom.models.mimo_v2_flash.MiMoV2FlashForCausalLM",
}
# seed = 34567
@@ -1027,11 +1028,15 @@ def allocate_forward_vars(self):
def _get_num_kv_heads(self):
"""Return the per-rank number of KV heads."""
hf_config = self.config.hf_config
if hf_config.num_key_value_heads >= self.world_size:
assert hf_config.num_key_value_heads % self.world_size == 0
return hf_config.num_key_value_heads // self.world_size
num_kv_heads_cfg = getattr(
hf_config, "num_key_value_heads",
getattr(hf_config, "num_attention_groups", None)
)
if num_kv_heads_cfg >= self.world_size:
assert num_kv_heads_cfg % self.world_size == 0
return num_kv_heads_cfg // self.world_size
else:
assert self.world_size % hf_config.num_key_value_heads == 0
assert self.world_size % num_kv_heads_cfg == 0
return 1

def _get_total_num_layers(self):
@@ -1321,11 +1326,15 @@ def allocate_kv_cache(self, num_kvcache_blocks):
self.num_physical_kvcache_blocks = (
num_kvcache_blocks * self.attn_metadata_builder.block_ratio
)
if hf_config.num_key_value_heads >= self.world_size:
assert hf_config.num_key_value_heads % self.world_size == 0
num_kv_heads = hf_config.num_key_value_heads // self.world_size
num_kv_heads_cfg = getattr(
hf_config, "num_key_value_heads",
getattr(hf_config, "num_attention_groups", None)
)
if num_kv_heads_cfg >= self.world_size:
assert num_kv_heads_cfg % self.world_size == 0
num_kv_heads = num_kv_heads_cfg // self.world_size
else:
assert self.world_size % hf_config.num_key_value_heads == 0
assert self.world_size % num_kv_heads_cfg == 0
num_kv_heads = 1

# Calculate total number of layers (target + draft)
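The same KV-head fallback now appears in both `_get_num_kv_heads` and `allocate_kv_cache` above. A minimal standalone sketch of the rule, assuming only the `num_key_value_heads` / `num_attention_groups` fallback shown in the diff; the example configs below are made up:

```python
from types import SimpleNamespace

def per_rank_kv_heads(hf_config, world_size: int) -> int:
    # Prefer num_key_value_heads; fall back to num_attention_groups (as in the diff).
    num_kv = getattr(
        hf_config, "num_key_value_heads",
        getattr(hf_config, "num_attention_groups", None),
    )
    if num_kv >= world_size:
        # Enough heads to give each TP rank a disjoint subset.
        assert num_kv % world_size == 0
        return num_kv // world_size
    # Fewer heads than ranks: heads are replicated, each rank holds one.
    assert world_size % num_kv == 0
    return 1

print(per_rank_kv_heads(SimpleNamespace(num_key_value_heads=8), world_size=4))   # 2
print(per_rank_kv_heads(SimpleNamespace(num_attention_groups=4), world_size=8))  # 1
```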
11 changes: 11 additions & 0 deletions atom/model_loader/loader.py
@@ -337,10 +337,21 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None:
maybe_matching_name,
f"mlp.experts.{hf_config.n_routed_experts}.",
)
# Check fused expert format before packed_modules_mapping to avoid
# expert weights (e.g. moe.gate_proj) being incorrectly matched
# by packed_modules_mapping entries (e.g. gate_proj -> gate_up_proj).
if detect_fused_expert_fn is not None and not is_fused_expert:
if detect_fused_expert_fn(name):
is_fused_expert = True
if get_fused_expert_mapping_fn is not None:
fused_expert_params_mapping = get_fused_expert_mapping_fn()
for k in packed_modules_mapping:
# We handle the experts below in expert_params_mapping
if "mlp.experts." in name and name not in params_dict:
continue
# Skip fused expert weights — handled below in expert loading path
if is_fused_expert and detect_fused_expert_fn is not None and detect_fused_expert_fn(name):
continue
if k in name:
packed_value = packed_modules_mapping[k]
# Handle both tuple (fuse parameter) and list (shard parameter)
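A toy illustration of the ordering issue described in the comment above: if the packed-module match ran first, an expert weight whose name contains `gate_proj` would be captured by the `gate_proj -> gate_up_proj` entry instead of the expert loading path. The weight names and the detector below are hypothetical; only the collision itself is taken from the comment.

```python
packed_modules_mapping = {"gate_proj": "gate_up_proj"}  # hypothetical entry, per the comment

def detect_fused_expert(name: str) -> bool:
    # Stand-in for detect_fused_expert_fn: fused expert weights live under "moe.experts".
    return "moe.experts" in name

weights = [
    "layers.0.mlp.gate_proj.weight",          # dense projection -> packed module path
    "layers.3.moe.experts.gate_proj.weight",  # fused expert weight -> expert loading path
]

for name in weights:
    if detect_fused_expert(name):             # must run before the packed_modules_mapping loop
        print(name, "-> fused expert loading path")
        continue
    for key, packed in packed_modules_mapping.items():
        if key in name:
            print(name, "-> packed module", packed)
```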
4 changes: 2 additions & 2 deletions atom/model_ops/attentions/aiter_attention.py
@@ -83,7 +83,7 @@ def __init__(
else:
max_qlen = 1

num_head_k = max(1, hf_config.num_key_value_heads // get_tp_group().world_size)
num_head_k = max(1, getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_groups", None)) // get_tp_group().world_size)
(
(work_meta_data_size, work_meta_data_type),
(work_indptr_size, work_indptr_type),
@@ -240,7 +240,7 @@ def set_aiter_persistent_worker_buffers(self, bs: int):
hf_config = config.hf_config
num_query_heads = self.num_attention_heads
num_kv_heads = max(
1, hf_config.num_key_value_heads // get_tp_group().world_size
1, getattr(hf_config, "num_key_value_heads", getattr(hf_config, "num_attention_groups", None)) // get_tp_group().world_size
)
block_size = self.block_size

185 changes: 165 additions & 20 deletions atom/model_ops/moe.py
@@ -485,7 +485,43 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

layer.w13_weight = atom_parameter(self._maybe_pad_weight(layer.w13_weight.data))
layer.w2_weight = atom_parameter(self._maybe_pad_weight(layer.w2_weight.data))
# reshaping weights is required for aiter moe kernel.

# gfx950 CK a16w16 stage2 requires inter_dim % 64 == 0.
# For tp=4 (inter=320) and tp=8 (inter=160), pad inter_dim up to the
# next multiple of 64. Zero padding is safe because fused_moe clips
# routed-weight contributions and zero-padded rows contribute nothing.
# Verified 2026-04-24: cos_sim >= 0.9999 for inter=160->192 and
# inter=320->384 vs torch reference.
w13 = layer.w13_weight.data # [E, 2*inter, hidden]
w2 = layer.w2_weight.data # [E, hidden, inter]
inter_dim = w2.shape[2]
# Stage1 dispatch: inter<=192 uses NPerBlock=64, inter>192 uses NPerBlock=128.
# Stage2 dispatch: inter>192 uses KPerBlock=64.
# So required alignment: 64 when inter<=192, 128 when inter>192.
# (inter=160->192 satisfies 192%64=0; inter=320->384 satisfies 384%128=0 and 384%64=0)
align = 64 if inter_dim <= 192 else 128
inter_pad = (inter_dim + align - 1) // align * align
if inter_pad != inter_dim:
E, _, hidden = w13.shape
# pad w13: gate half [E, inter, hidden] and up half [E, inter, hidden]
w13_new = torch.zeros(
E, 2 * inter_pad, hidden, dtype=w13.dtype, device=w13.device
)
w13_new[:, :inter_dim, :] = w13[:, :inter_dim, :] # gate
w13_new[:, inter_pad : inter_pad + inter_dim, :] = w13[:, inter_dim:, :] # up
# pad w2: [E, hidden, inter_pad]
w2_new = torch.zeros(
E, hidden, inter_pad, dtype=w2.dtype, device=w2.device
)
w2_new[:, :, :inter_dim] = w2
layer.w13_weight = atom_parameter(w13_new)
layer.w2_weight = atom_parameter(w2_new)

# Shuffle weights for CK/ASM kernels.
# Previously skipped for gfx950 bf16 g1u1 on the assumption that the CK
# 2-stage preshuffle_off (NSwizzle=0) kernel expected un-shuffled weights.
# Verified 2026-04-23: preshuffle_off GEMM is wrong on gfx950; preshuffle_on
# (NSwizzle=1) is correct. Always shuffle so the right kernel path is used.
shuffle_weights(layer.w13_weight, layer.w2_weight)

def get_fused_moe_quant_config(
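A compact sketch of the inter_dim padding applied above, using the 64/128 alignment rule from the comments; the tensors are dummies and only the shapes matter. Zero-padded intermediate channels produce zero activations, and the matching zero columns of w2 contribute nothing to the output, which is why the padding is numerically safe.

```python
import torch

def pad_inter_dim_bf16(w13: torch.Tensor, w2: torch.Tensor):
    # w13: [E, 2*inter, hidden] (gate stacked over up), w2: [E, hidden, inter]
    inter = w2.shape[2]
    align = 64 if inter <= 192 else 128          # stage1/stage2 dispatch rule from the comments
    inter_pad = (inter + align - 1) // align * align
    if inter_pad == inter:
        return w13, w2
    E, _, hidden = w13.shape
    w13_new = torch.zeros(E, 2 * inter_pad, hidden, dtype=w13.dtype)
    w13_new[:, :inter, :] = w13[:, :inter, :]                        # gate half
    w13_new[:, inter_pad:inter_pad + inter, :] = w13[:, inter:, :]   # up half
    w2_new = torch.zeros(E, hidden, inter_pad, dtype=w2.dtype)
    w2_new[:, :, :inter] = w2
    return w13_new, w2_new

# tp=8: inter=160 -> 192; tp=4: inter=320 -> 384 (the verified cases mentioned above).
w13, w2 = pad_inter_dim_bf16(torch.randn(4, 2 * 160, 32), torch.randn(4, 32, 160))
print(w13.shape, w2.shape)  # torch.Size([4, 384, 32]) torch.Size([4, 32, 192])
```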
@@ -998,6 +1034,13 @@ def apply(
global_num_experts=global_num_experts,
expert_map=expert_map,
)
_tp_size = getattr(self, 'tp_size', getattr(self.moe, 'tp_size', 1)) if hasattr(self, 'moe') else getattr(self, 'tp_size', 1)
_ep_size = getattr(self, 'ep_size', getattr(self.moe, 'ep_size', 1)) if hasattr(self, 'moe') else getattr(self, 'ep_size', 1)
if layer.reduce_results and (_tp_size > 1 or _ep_size > 1):
from aiter.dist.parallel_state import get_tp_group
_moe_result = get_tp_group().all_reduce(
_moe_result, ca_fp8_quant=False
)
return _moe_result

assert (
@@ -1516,25 +1559,38 @@ def create_weights(
block_n = 1
block_k = 32
tp_size = get_tp_group().world_size
# Pad intermediate_size_per_partition to the nearest multiple of
# block_n so that both the CK blockscale kernel alignment constraints
# and the block-quantization scale alignment constraints are satisfied.
# For tp=4, inter_dim=320 is not divisible by block_n=128; padding to
# 384 (=3×128) satisfies both stage1 NPerBlock=128 and block_n=128.
# The weight loader already supports partial loading into a padded
# buffer via the MXFP4 alignment path (_load_w13 / _load_w2).
padded_inter = (
(intermediate_size_per_partition + block_n - 1) // block_n * block_n
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
if padded_inter % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"{padded_inter} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
if tp_size > 1 and padded_inter % block_k != 0:
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"{padded_inter} is not divisible by "
f"weight quantization block_k = {block_k}."
)

# WEIGHTS
# Allocated at original (un-padded) size; inter_dim padding is applied
# later in _process_block_quant (after normalize, before shuffle_weights),
# mirroring the BF16 approach in UnquantizedFusedMoEMethod.
w13_weight = atom_parameter(
torch.empty(
num_experts,
@@ -1654,6 +1710,41 @@ def _process_block_quant(self, layer: nn.Module) -> None:
assert self.quant_config.is_dynamic
self._normalize_weights_and_scales(layer)

# Inter-dim padding for block-quantized FP8 (mirrors BF16 approach in
# UnquantizedFusedMoEMethod.process_weights_after_loading).
# When inter_dim is not a multiple of block_n (e.g. tp=4: 320 % 128 ≠ 0),
# zero-pad both weights to the nearest block_n multiple BEFORE shuffling.
# Padding area is zero so dequant(0, scale) = 0 is numerically safe.
# Scale tensors use ceil(inter/block_n) and are already shape-compatible.
inter_dim = layer.w2_weight.shape[-1]
block_n = 128 if self.quant_type == QuantType.per_1x128 else 32
# FP8 blockscale stage2 requires KPerBlock=128 (gfx950 FP8 mfma KPack=32 constraint
# prevents KPerBlock=64). align must always be block_n(=128) so that inter_pad%128==0.
# Bug fix: previously used align=64 for inter<=192 (copied from BF16 path), but
# 192%128=64!=0 → stage2 kernel dispatch fails. Correct: always align to block_n.
# tp=8 inter=160 → 256 (not 192 as in the BF16 path, since ceil(160/128)*128=256); tp=4 inter=320 → 384.
align = block_n
inter_pad = (inter_dim + align - 1) // align * align
if inter_pad != inter_dim:
E = layer.w13_weight.shape[0]
hidden = layer.w13_weight.shape[-1]
w13 = layer.w13_weight.data
w13_new = torch.zeros(
E, 2 * inter_pad, hidden,
dtype=w13.dtype, device=w13.device
)
w13_new[:, :inter_dim, :] = w13[:, :inter_dim, :] # gate
w13_new[:, inter_pad:inter_pad + inter_dim, :] = w13[:, inter_dim:, :] # up
layer.w13_weight = atom_parameter(w13_new)

w2 = layer.w2_weight.data
w2_new = torch.zeros(
E, hidden, inter_pad,
dtype=w2.dtype, device=w2.device
)
w2_new[:, :, :inter_dim] = w2
layer.w2_weight = atom_parameter(w2_new)

if not self.need_normalize_e4m3fn_to_e4m3fnuz:
layer.w13_weight = atom_parameter(layer.w13_weight.data)
layer.w13_weight_scale = atom_parameter(layer.w13_weight_scale.data)
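The difference from the BF16 path is only the alignment constant, but it changes the padded sizes. A quick check of the arithmetic in the comment above (plain Python, no framework assumptions):

```python
def fp8_inter_pad(inter_dim: int, block_n: int = 128) -> int:
    # FP8 blockscale: align must be block_n itself, not the BF16 rule of 64 for inter<=192.
    return (inter_dim + block_n - 1) // block_n * block_n

print(fp8_inter_pad(160))  # 256 (the BF16 rule would give 192, and 192 % 128 != 0)
print(fp8_inter_pad(320))  # 384
print(192 % 128)           # 64 -> why align=64 broke stage2 kernel dispatch
```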
@@ -1726,14 +1817,19 @@ def get_fused_moe_quant_config(
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
)
elif self.quant_type == QuantType.per_1x128:
block_shape = [128, 128]
elif self.quant_type == QuantType.per_1x32:
block_shape = [1, 32]
else:
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=None,
)
block_shape = None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=block_shape,
)

@mark_trace(prefix="fp8_moe", torch_compile=False)
def apply(
@@ -2207,10 +2303,36 @@ def _load_w13(
expert_shard_size = expert_data.shape[shard_dim] // 2
# Derive shard size from loaded_weight (unpadded checkpoint) to avoid
# out-of-bounds when expert_data is padded (e.g. MXFP4 alignment).
load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
loaded_weight = loaded_weight.narrow(
shard_dim, load_shard_size * tp_rank, load_shard_size
)
# Use ceil so that the last partial scale block (e.g. per_1x128 with
# inter=1280 and tp=4: 10 blocks / 4 = 2.5 → ceil=3) is included.
# Without ceil, the 3rd scale block is never copied and stays at the
# torch.ones() initial value of 1.0, causing ~5000× dequant error.
load_shard_size = (
loaded_weight.shape[shard_dim] + self.tp_size - 1
) // self.tp_size
start = load_shard_size * tp_rank
# When D < tp_size (e.g. per_1x128 scale block count smaller than
# tp_size, observed at tp=8 with inter=1280 → D=10), the ceil split
# gives some trailing ranks start >= D so they hold no slice of the
# loaded tensor. Skip narrow + copy_ for those ranks; the rank's
# slice of expert_data stays at its initialised value (0 for weight,
# 1.0 for scale) and the rank contributes a no-op to the column
# gather / row reduction.
if start >= loaded_weight.shape[shard_dim]:
# FP8 scale tensors are torch.ones() initialised. If we leave the
# trailing rank's slice at 1.0, the downstream FP8 dequant multiplies
# the (uninitialised) fp8 weight by 1.0 instead of the correct
# quantization scale, contaminating the column gather / row reduction.
# Zero the slot so dequant produces 0 and the rank contributes a
# true no-op (matches MXFP4 scale init at moe.py:776,813).
if expert_data.dtype == torch.float32:
if shard_id == "w1":
expert_data.narrow(shard_dim, 0, expert_shard_size).zero_()
else:
expert_data.narrow(shard_dim, expert_shard_size, expert_shard_size).zero_()
return
size = min(load_shard_size, loaded_weight.shape[shard_dim] - start)
loaded_weight = loaded_weight.narrow(shard_dim, start, size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
@@ -2246,10 +2368,24 @@ def _load_w2(
if not load_full:
# Derive shard size from loaded_weight (unpadded checkpoint) to
# avoid out-of-bounds when expert_data is padded (e.g. MXFP4).
load_shard_size = loaded_weight.shape[shard_dim] // self.tp_size
loaded_weight = loaded_weight.narrow(
shard_dim, load_shard_size * tp_rank, load_shard_size
)
# Use ceil (same reason as _load_w13: partial last scale block).
load_shard_size = (
loaded_weight.shape[shard_dim] + self.tp_size - 1
) // self.tp_size
start = load_shard_size * tp_rank
# See _load_w13 comment above: when D < tp_size the ceil split
# leaves trailing ranks with no slice; skip narrow + copy_.
if start >= loaded_weight.shape[shard_dim]:
# Zero the scale slice so dequant=0 instead of multiplying by
# stale init=1.0; see _load_w13 comment for full rationale.
if expert_data.dtype == torch.float32:
if load_shard_size != shard_size:
expert_data.narrow(shard_dim, 0, load_shard_size).zero_()
else:
expert_data.zero_()
return
size = min(load_shard_size, loaded_weight.shape[shard_dim] - start)
loaded_weight = loaded_weight.narrow(shard_dim, start, size)
if load_shard_size != shard_size:
expert_data = expert_data.narrow(shard_dim, 0, load_shard_size)
# w2, down_proj: Load into only logical weight of w2.
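A standalone sketch of the ceil-based shard split that _load_w13 / _load_w2 now use; only the arithmetic comes from the diff, the printed layout is illustrative. With per_1x128 scales and inter=1280 there are D = 1280 // 128 = 10 scale blocks: at tp=4 the last rank gets a partial slice, and at tp=8 ranks 5-7 get no slice at all and must zero their scale slots.

```python
def shard_layout(dim: int, tp_size: int):
    shard = (dim + tp_size - 1) // tp_size              # ceil split, keeps the partial last block
    layout = []
    for rank in range(tp_size):
        start = shard * rank
        if start >= dim:
            layout.append((rank, None))                 # trailing rank: no slice, zero the scale slot
        else:
            layout.append((rank, (start, min(shard, dim - start))))
    return layout

print(shard_layout(10, 4))  # [(0, (0, 3)), (1, (3, 3)), (2, (6, 3)), (3, (9, 1))]
print(shard_layout(10, 8))  # ranks 5, 6, 7 -> None
```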
@@ -2556,6 +2692,14 @@ def select_experts(
routed_scaling_factor: float = 1.0,
):

# Model-provided custom routing (e.g. sigmoid + bias + scaling for Step-3.5)
if custom_routing_function is not None:
return custom_routing_function(
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)

# DeepSeek-V2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
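A sketch of what a model-provided routing callback compatible with the keyword signature above could look like. The sigmoid + bias + scaling recipe follows the comment, but the concrete bias handling, the scaling value, and the (topk_weights, topk_ids) return convention are assumptions for illustration, not Step-3.5's actual implementation.

```python
import torch

def sigmoid_bias_routing(gating_output: torch.Tensor, topk: int, renormalize: bool,
                         e_score_bias: torch.Tensor | None = None,
                         routed_scaling_factor: float = 1.0):
    scores = torch.sigmoid(gating_output.float())
    # Select experts on biased scores, but weight them with the unbiased ones.
    selection_scores = scores if e_score_bias is None else scores + e_score_bias
    _, topk_ids = torch.topk(selection_scores, k=topk, dim=-1)
    topk_weights = torch.gather(scores, dim=-1, index=topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_ids

weights, ids = sigmoid_bias_routing(torch.randn(2, 8), topk=2, renormalize=True)
print(weights.shape, ids.shape)  # torch.Size([2, 2]) torch.Size([2, 2])
```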
@@ -2697,6 +2841,7 @@ def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor)
hidden_states = naive_multicast(hidden_states, cu_tokens_across_dp_cpu)
router_logits = naive_multicast(router_logits, cu_tokens_across_dp_cpu)


# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,