From 80142ca6b0bb7d8da531f07f34dbb85ecccdb5fb Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 22 Apr 2026 02:24:04 +0530 Subject: [PATCH 1/7] feat: NSP-blocked MoE prefill dispatch for Qwen3MOE and GPT-OSS Add expert-blocked NSP-parallel prefill forward to QEffPrefillChunkedQwen3MoeSparseMoeBlock and QEffPrefillOnlyChunkedGptOssMLP. Controlled via EXPERT_BLOCKING_NUM_NSP env var. Fix CtxScatterFunc3D/CtxGatherFunc3D eager forward for INT32_MAX sentinel handling. Add disagg-mode tests for both models with tiny configs. Signed-off-by: vbaddi --- QEfficient/customop/ctx_scatter_gather.py | 4 +- .../models/gpt_oss/modeling_gpt_oss.py | 89 ++++++++++++ .../models/qwen3_moe/modeling_qwen3_moe.py | 129 +++++++++++++++-- .../models/test_moe_prefill_blocked.py | 132 ++++++++++++++++++ 4 files changed, 340 insertions(+), 14 deletions(-) create mode 100644 tests/transformers/models/test_moe_prefill_blocked.py diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 59bfe6af0..bc8775707 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -78,8 +78,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates class CtxScatterFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() batch_idx = torch.arange(data.shape[0]).view(-1, 1) - ctx_idx = position_ids + ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids) data[batch_idx, ctx_idx] = updates return data @@ -103,6 +104,7 @@ class CtxGatherFunc3D(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, ctx_indices: torch.Tensor): batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, ctx_indices] @staticmethod diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6f805bfd4..5707dc209 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -36,6 +36,7 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -50,7 +51,91 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) + + +def _ctx_scatter_gather_gptoss_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + b_g: torch.Tensor, + b_u: torch.Tensor, + b_d: torch.Tensor, + limit: float, + alpha: float, + T: int, +) -> torch.Tensor: + """Packed-prefix expert helper for GPT-OSS NSP-blocked dispatch.""" + batch_size, hidden_size = T2Ei.shape[0], x.shape[1] + scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) + invalid_mask = ~T2Ei + INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) + scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + + x_prime = 
torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) + x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) + valid_output_rows = row_range < valid_rows + x_prime = torch.where(valid_output_rows.unsqueeze(-1), x_prime, torch.zeros_like(x_prime)) + + gate = (x_prime @ W_g) + b_g.unsqueeze(1) + up = (x_prime @ W_u) + b_u.unsqueeze(1) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + intermediate = (up + 1) * glu + down_prime = (intermediate @ W_d) + b_d.unsqueeze(1) + down_prime = torch.where(valid_output_rows.unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + + gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) + delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) + return delta_out + + class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + def __qeff_init__(self): + pass + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = EXPERT_BLOCKING_NUM_NSP + E = self.experts.num_experts + if E % num_nsp != 0: + raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") + local_experts = E // num_nsp + I = self.experts.gate_proj.shape[2] # noqa: E741 + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() + expert_out_partial = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :].unsqueeze(-1) + T2Ei = routing_weight.squeeze(-1) > 0 + delta = _ctx_scatter_gather_gptoss_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + b_g=b_g[:, slot], + b_u=b_u[:, slot], + b_d=b_d[:, slot], + limit=self.experts.limit, + alpha=self.experts.alpha, + T=T, + ) + expert_out_partial = expert_out_partial + (delta * routing_weight) + return expert_out_partial.sum(dim=0) + def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape T = B * S @@ -69,6 +154,10 @@ def forward(self, hidden: torch.Tensor): # Routing weights for each expert [T, E] routing_weights = masked_logits + if self.experts.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + expert_out = self._forward_expert_blocked(x=hidden, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + # ────────────────── allocate the output tensor ───── expert_out = hidden.new_zeros((T, H)) # accumulation buffer diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index de92eae8f..415c1c396 100644 --- 
a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +import os from typing import List, Optional, Tuple, Type import torch @@ -32,6 +33,7 @@ generic_blocked_attention_interface, past_key_value_update, ) +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -100,8 +102,85 @@ def eager_attention_forward( return attn_output, attn_weights +EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) + + +def _ctx_scatter_gather_expert_blocked( + x: torch.Tensor, + T2Ei: torch.Tensor, + W_g: torch.Tensor, + W_u: torch.Tensor, + W_d: torch.Tensor, + act_fn, + T: int, +) -> torch.Tensor: + """Packed-prefix expert helper for NSP-blocked dispatch.""" + batch_size, hidden_size = T2Ei.shape[0], x.shape[1] + scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) + invalid_mask = ~T2Ei + INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) + scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + + x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) + x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + + gate_prime = x_prime @ W_g + up_prime = x_prime @ W_u + down_prime = (up_prime * act_fn(gate_prime)) @ W_d + + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) + down_prime = torch.where((row_range < valid_rows).unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + + gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) + delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) + return delta_out + + class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def __qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) + self.up_proj_w = torch.stack(self.up_proj_w) + self.down_proj_w = torch.stack(self.down_proj_w) + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = EXPERT_BLOCKING_NUM_NSP + if self.num_experts % num_nsp != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" + ) + local_experts = self.num_experts // num_nsp + rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() + W_d = self.down_proj_w.view(local_experts, num_nsp, 
-1, H).transpose(0, 1).contiguous() + expert_out_partial = x.new_zeros((num_nsp, T, H)) + for slot in range(local_experts): + routing_weight = rw[:, slot, :].unsqueeze(-1) + T2Ei = routing_weight.squeeze(-1) > 0 + delta = _ctx_scatter_gather_expert_blocked( + x=x, + T2Ei=T2Ei, + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + act_fn=self.experts[0].act_fn, + T=T, + ) + expert_out_partial = expert_out_partial + (delta * routing_weight) + return expert_out_partial.sum(dim=0) + + def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) @@ -113,20 +192,44 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens top_w = top_w.to(hidden_states.dtype) masked_logits = torch.zeros_like(router_logits) masked_logits.scatter_(1, top_i, top_w) - # Routing weights for each expert [T, E] routing_weights = masked_logits - # ────────────────── allocate the output tensor ───── - expert_out = x.new_zeros((T, H)) # accumulation buffer - # ───────────────────────── Expert computation loop ───────────────────────────── + expert_out = x.new_zeros((T, H)) + for e in range(self.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T + W_d = self.experts[e].down_proj.weight.T + gate = x @ W_g + up = x @ W_u + down = (up * self.experts[e].act_fn(gate)) @ W_d + expert_out += down * routing_weight + return expert_out.view(B, S, H), router_logits + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + router_logits = self.gate(x) + prob = F.softmax(router_logits, -1, dtype=torch.float) + top_w, top_i = torch.topk(prob, self.top_k, -1) + if self.norm_topk_prob: + top_w /= top_w.sum(-1, keepdim=True) + top_w = top_w.to(hidden_states.dtype) + routing_weights = torch.zeros_like(router_logits) + routing_weights.scatter_(1, top_i, top_w) + + if self.num_experts % EXPERT_BLOCKING_NUM_NSP == 0: + expert_out = self._forward_expert_blocked(x=x, routing_weights=routing_weights) + return expert_out.view(B, S, H), router_logits + + expert_out = x.new_zeros((T, H)) for e in range(self.num_experts): - routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I] - W_d = self.experts[e].down_proj.weight.T # [I, H] - gate = x @ W_g # [T, I] - up = x @ W_u # [T, I] - down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H] - masked_down = down * routing_weight - expert_out += masked_down + routing_weight = routing_weights[:, e].unsqueeze(-1) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T + W_d = self.experts[e].down_proj.weight.T + gate = x @ W_g + up = x @ W_u + down = (up * self.experts[e].act_fn(gate)) @ W_d + expert_out += down * routing_weight return expert_out.view(B, S, H), router_logits diff --git a/tests/transformers/models/test_moe_prefill_blocked.py b/tests/transformers/models/test_moe_prefill_blocked.py new file mode 100644 index 000000000..ca2297543 --- /dev/null +++ b/tests/transformers/models/test_moe_prefill_blocked.py @@ -0,0 +1,132 @@ +""" +Tests for NSP-blocked MoE prefill dispatch (Qwen3MOE + GPT-OSS). +Uses EXPERT_BLOCKING_NUM_NSP=2 so tests run fast on any num_experts. +Covers: parity, decode export, prefill+chunking export (disagg mode). 
+""" + +import os + +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +os.environ.setdefault("EXPERT_BLOCKING_NUM_NSP", "2") + +from QEfficient import QEFFAutoModelForCausalLM + +MODEL_KWARGS = {"attn_implementation": "eager"} + +QWEN3_MOE_CFG = dict( + max_position_embeddings=256, + num_hidden_layers=2, + num_attention_heads=4, + hidden_size=128, + intermediate_size=512, + vocab_size=127, + num_key_value_heads=2, +) +GPTOSS_CFG = dict( + max_position_embeddings=256, + num_hidden_layers=2, + num_attention_heads=2, + hidden_size=32, + intermediate_size=32, + vocab_size=127, + num_key_value_heads=2, +) + + +# ── Qwen3MOE ────────────────────────────────────────────────────────────────── + + +def test_qwen3moe_blocked_forward_parity(): + from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( + QEffPrefillChunkedQwen3MoeSparseMoeBlock, + ) + + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + + blocks = [ + m + for _, m in model.named_modules() + if hasattr(m, "experts") and hasattr(m, "gate") and hasattr(m, "num_experts") + ] + assert blocks + + block = blocks[0] + chunked = QEffPrefillChunkedQwen3MoeSparseMoeBlock.__new__(QEffPrefillChunkedQwen3MoeSparseMoeBlock) + chunked.__dict__.update(block.__dict__) + chunked.__class__ = QEffPrefillChunkedQwen3MoeSparseMoeBlock + chunked.__qeff_init__() + + x = torch.randn(1, 8, config.hidden_size) + with torch.no_grad(): + orig, _ = chunked.orig_forward(x) + blocked, _ = chunked.forward(x) + + assert orig.shape == blocked.shape + assert (orig - blocked).abs().max().item() < 0.1, "Qwen3MOE parity failed" + + +def test_qwen3moe_decode_export(tmp_path): + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "decode") + assert qeff.onnx_path.is_file() + + +def test_qwen3moe_prefill_chunked_export(tmp_path): + config = AutoConfig.for_model("qwen3_moe", **QWEN3_MOE_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "prefill", prefill_only=True, enable_chunking=True) + assert qeff.onnx_path.is_file() + + +# ── GPT-OSS ─────────────────────────────────────────────────────────────────── + + +def test_gptoss_blocked_forward_parity(): + from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffPrefillOnlyChunkedGptOssMLP, + ) + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyChunkedTransform + + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + + blocks_orig = [m for _, m in model.named_modules() if m.__class__.__name__ == "GptOssMLP"] + assert blocks_orig + + x = torch.randn(1, 8, config.hidden_size) + with torch.no_grad(): + orig, _ = blocks_orig[0].forward(x) + + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + PrefillOnlyChunkedTransform.apply(qeff.model) + + blocks_chunked = [m for _, m in qeff.model.named_modules() if isinstance(m, QEffPrefillOnlyChunkedGptOssMLP)] + assert blocks_chunked + + with torch.no_grad(): + blocked, _ = blocks_chunked[0].forward(x) + + assert orig.shape == blocked.shape + assert (orig - blocked).abs().max().item() < 0.1, "GPT-OSS parity failed" + + +def test_gptoss_decode_export(tmp_path): + config = 
AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "decode") + assert qeff.onnx_path.is_file() + + +def test_gptoss_prefill_chunked_export(tmp_path): + config = AutoConfig.for_model("gpt_oss", **GPTOSS_CFG) + model = AutoModelForCausalLM.from_config(config, **MODEL_KWARGS) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=False) + qeff.export(tmp_path / "prefill", prefill_only=True, enable_chunking=True) + assert qeff.onnx_path.is_file() From a5bd93a48ba2bc55fad8538dfb6cae4b34fe9070 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 22 Apr 2026 03:09:20 +0530 Subject: [PATCH 2/7] nit: weights re-route fixes Signed-off-by: vbaddi --- .../models/gpt_oss/modeling_gpt_oss.py | 39 +++++++++------- .../models/qwen3_moe/modeling_qwen3_moe.py | 46 +++++++++++-------- .../models/test_moe_prefill_blocked.py | 7 +++ 3 files changed, 55 insertions(+), 37 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5707dc209..5e4d1547a 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -99,23 +99,28 @@ def _ctx_scatter_gather_gptoss_expert_blocked( class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): def __qeff_init__(self): - pass - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP E = self.experts.num_experts if E % num_nsp != 0: raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") local_experts = E // num_nsp - I = self.experts.gate_proj.shape[2] # noqa: E741 + H = self.experts.hidden_size + I = self.experts.expert_dim # noqa: E741 + with torch.no_grad(): + self._blocked_W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + self._blocked_b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + self._blocked_b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + self._blocked_b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() + self._blocked_num_nsp = num_nsp + self._blocked_local_experts = local_experts + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = self._blocked_num_nsp + local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, 
H).transpose(0, 1).contiguous() expert_out_partial = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): routing_weight = rw[:, slot, :].unsqueeze(-1) @@ -123,12 +128,12 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor delta = _ctx_scatter_gather_gptoss_expert_blocked( x=x, T2Ei=T2Ei, - W_g=W_g[:, slot], - W_u=W_u[:, slot], - W_d=W_d[:, slot], - b_g=b_g[:, slot], - b_u=b_u[:, slot], - b_d=b_d[:, slot], + W_g=self._blocked_W_g[:, slot], + W_u=self._blocked_W_u[:, slot], + W_d=self._blocked_W_d[:, slot], + b_g=self._blocked_b_g[:, slot], + b_u=self._blocked_b_u[:, slot], + b_d=self._blocked_b_d[:, slot], limit=self.experts.limit, alpha=self.experts.alpha, T=T, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 415c1c396..13122dfb1 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -140,30 +140,36 @@ def _ctx_scatter_gather_expert_blocked( class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): def __qeff_init__(self): - self.gate_proj_w = [] - self.up_proj_w = [] - self.down_proj_w = [] - with torch.no_grad(): - for e in range(self.num_experts): - self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) - self.up_proj_w.append(self.experts[e].up_proj.weight.T) - self.down_proj_w.append(self.experts[e].down_proj.weight.T) - self.gate_proj_w = torch.stack(self.gate_proj_w) - self.up_proj_w = torch.stack(self.up_proj_w) - self.down_proj_w = torch.stack(self.down_proj_w) - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP if self.num_experts % num_nsp != 0: raise ValueError( f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" ) local_experts = self.num_experts // num_nsp + gate_proj_w = [] + up_proj_w = [] + down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + gate_proj_w.append(self.experts[e].gate_proj.weight.T) + up_proj_w.append(self.experts[e].up_proj.weight.T) + down_proj_w.append(self.experts[e].down_proj.weight.T) + stacked_g = torch.stack(gate_proj_w) # [E, H, I] + stacked_u = torch.stack(up_proj_w) + stacked_d = torch.stack(down_proj_w) # [E, I, H] + H = stacked_g.shape[1] + I = stacked_g.shape[2] # noqa: E741 + self._blocked_W_g = stacked_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_u = stacked_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + self._blocked_W_d = stacked_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + self._blocked_num_nsp = num_nsp + self._blocked_local_experts = local_experts + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape + num_nsp = self._blocked_num_nsp + local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() - W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() expert_out_partial = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): routing_weight = rw[:, slot, :].unsqueeze(-1) 
@@ -171,9 +177,9 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor delta = _ctx_scatter_gather_expert_blocked( x=x, T2Ei=T2Ei, - W_g=W_g[:, slot], - W_u=W_u[:, slot], - W_d=W_d[:, slot], + W_g=self._blocked_W_g[:, slot], + W_u=self._blocked_W_u[:, slot], + W_d=self._blocked_W_d[:, slot], act_fn=self.experts[0].act_fn, T=T, ) diff --git a/tests/transformers/models/test_moe_prefill_blocked.py b/tests/transformers/models/test_moe_prefill_blocked.py index ca2297543..f7789707e 100644 --- a/tests/transformers/models/test_moe_prefill_blocked.py +++ b/tests/transformers/models/test_moe_prefill_blocked.py @@ -1,3 +1,10 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + """ Tests for NSP-blocked MoE prefill dispatch (Qwen3MOE + GPT-OSS). Uses EXPERT_BLOCKING_NUM_NSP=2 so tests run fast on any num_experts. From c4ef4c847b37ee77cee8453bb4cbdbebc7149754 Mon Sep 17 00:00:00 2001 From: vbaddi Date: Wed, 22 Apr 2026 03:21:45 +0530 Subject: [PATCH 3/7] nit: weights re-route fixes v1 Signed-off-by: vbaddi --- .../models/gpt_oss/modeling_gpt_oss.py | 36 ++++++--------- .../models/qwen3_moe/modeling_qwen3_moe.py | 46 ++++++++----------- 2 files changed, 34 insertions(+), 48 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5e4d1547a..84eb4acac 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -98,29 +98,21 @@ def _ctx_scatter_gather_gptoss_expert_blocked( class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): - def __qeff_init__(self): + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP E = self.experts.num_experts if E % num_nsp != 0: raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") local_experts = E // num_nsp - H = self.experts.hidden_size I = self.experts.expert_dim # noqa: E741 - with torch.no_grad(): - self._blocked_W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - self._blocked_b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - self._blocked_b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - self._blocked_b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() - self._blocked_num_nsp = num_nsp - self._blocked_local_experts = local_experts - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape - num_nsp = self._blocked_num_nsp - local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() + 
W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() + b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() expert_out_partial = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): routing_weight = rw[:, slot, :].unsqueeze(-1) @@ -128,12 +120,12 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor delta = _ctx_scatter_gather_gptoss_expert_blocked( x=x, T2Ei=T2Ei, - W_g=self._blocked_W_g[:, slot], - W_u=self._blocked_W_u[:, slot], - W_d=self._blocked_W_d[:, slot], - b_g=self._blocked_b_g[:, slot], - b_u=self._blocked_b_u[:, slot], - b_d=self._blocked_b_d[:, slot], + W_g=W_g[:, slot], + W_u=W_u[:, slot], + W_d=W_d[:, slot], + b_g=b_g[:, slot], + b_u=b_u[:, slot], + b_d=b_d[:, slot], limit=self.experts.limit, alpha=self.experts.alpha, T=T, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 13122dfb1..e233e0e83 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -140,36 +140,30 @@ def _ctx_scatter_gather_expert_blocked( class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): def __qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) # [E, H, I] + self.up_proj_w = torch.stack(self.up_proj_w) # [E, H, I] + self.down_proj_w = torch.stack(self.down_proj_w) # [E, I, H] + + def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP if self.num_experts % num_nsp != 0: raise ValueError( f"num_experts ({self.num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})" ) local_experts = self.num_experts // num_nsp - gate_proj_w = [] - up_proj_w = [] - down_proj_w = [] - with torch.no_grad(): - for e in range(self.num_experts): - gate_proj_w.append(self.experts[e].gate_proj.weight.T) - up_proj_w.append(self.experts[e].up_proj.weight.T) - down_proj_w.append(self.experts[e].down_proj.weight.T) - stacked_g = torch.stack(gate_proj_w) # [E, H, I] - stacked_u = torch.stack(up_proj_w) - stacked_d = torch.stack(down_proj_w) # [E, I, H] - H = stacked_g.shape[1] - I = stacked_g.shape[2] # noqa: E741 - self._blocked_W_g = stacked_g.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_u = stacked_u.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - self._blocked_W_d = stacked_d.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - self._blocked_num_nsp = num_nsp - self._blocked_local_experts = local_experts - - def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - T, H = x.shape - num_nsp = self._blocked_num_nsp - local_experts = self._blocked_local_experts rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + W_g = 
self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous()
+        W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous()
+        W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous()
         expert_out_partial = x.new_zeros((num_nsp, T, H))
         for slot in range(local_experts):
             routing_weight = rw[:, slot, :].unsqueeze(-1)
@@ -177,9 +171,9 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor
             delta = _ctx_scatter_gather_expert_blocked(
                 x=x,
                 T2Ei=T2Ei,
-                W_g=self._blocked_W_g[:, slot],
-                W_u=self._blocked_W_u[:, slot],
-                W_d=self._blocked_W_d[:, slot],
+                W_g=W_g[:, slot],
+                W_u=W_u[:, slot],
+                W_d=W_d[:, slot],
                 act_fn=self.experts[0].act_fn,
                 T=T,
             )

From 290839e6855a52ba4f921eeff734b95615c11dfe Mon Sep 17 00:00:00 2001
From: vbaddi
Date: Thu, 23 Apr 2026 22:25:50 +0530
Subject: [PATCH 4/7] nit(0423): GPT-OSS MoE fixes and cleanup

Signed-off-by: vbaddi
---
 QEfficient/base/modeling_qeff.py              |  4 +-
 QEfficient/transformers/modeling_utils.py     |  2 +-
 .../models/gpt_oss/modeling_gpt_oss.py        | 50 ++++++++++---------
 .../qwen3moe_disagg_mode_with_chunking.py     |  9 ++--
 4 files changed, 35 insertions(+), 30 deletions(-)

diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index e9213761d..a091de749 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -530,8 +530,8 @@ def _compile(
         onnx_path = Path(
             onnx_path
             if onnx_path
-            else self.onnx_path
-            if self.onnx_path
+            # else self.onnx_path
+            # if self.onnx_path
             else self.get_onnx_path(
                 prefill_only,
                 enable_chunking,
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index f9d7fe62c..183c19f6f 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -196,7 +196,7 @@
 DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}
 
 # This is for supporting different modelling classes specially written for prefill-only model
-SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "kimi_k2", "kimi_k25"}
+SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "qwen3_moe", "kimi_k2", "kimi_k25"}
diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
index 84eb4acac..1248e20ba 100644
--- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -67,12 +67,11 @@ def _ctx_scatter_gather_gptoss_expert_blocked(
     alpha: float,
     T: int,
 ) -> torch.Tensor:
-    """Packed-prefix expert helper for GPT-OSS NSP-blocked dispatch."""
     batch_size, hidden_size = T2Ei.shape[0], x.shape[1]
     scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32)
     invalid_mask = ~T2Ei
-    INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device)
-    scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx)
+    int32_max = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device)
+    scatter_safe_idx = torch.where(invalid_mask, int32_max, scatter_idx)
 
     x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device)
     x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1))
@@ -91,7 +90,7 @@ def _ctx_scatter_gather_gptoss_expert_blocked(
     down_prime = (intermediate @ W_d) + b_d.unsqueeze(1)
     down_prime 
= torch.where(valid_output_rows.unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) - gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + gather_idx = torch.where(invalid_mask, int32_max, scatter_idx) delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) return delta_out @@ -101,36 +100,41 @@ class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: T, H = x.shape num_nsp = EXPERT_BLOCKING_NUM_NSP - E = self.experts.num_experts - if E % num_nsp != 0: - raise ValueError(f"num_experts ({E}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") - local_experts = E // num_nsp - I = self.experts.expert_dim # noqa: E741 - rw = routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() - W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_u = self.experts.up_proj.view(local_experts, num_nsp, H, I).transpose(0, 1).contiguous() - W_d = self.experts.down_proj.view(local_experts, num_nsp, I, H).transpose(0, 1).contiguous() - b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() - b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, I).transpose(0, 1).contiguous() + num_experts = self.experts.num_experts + if num_experts % num_nsp != 0: + raise ValueError(f"num_experts ({num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})") + + local_experts = num_experts // num_nsp + expert_dim = self.experts.expert_dim + routing_weights_by_expert = ( + routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous() + ) + W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous() + W_u = self.experts.up_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous() + W_d = self.experts.down_proj.view(local_experts, num_nsp, expert_dim, H).transpose(0, 1).contiguous() + b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous() + b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous() b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() + expert_out_partial = x.new_zeros((num_nsp, T, H)) - for slot in range(local_experts): - routing_weight = rw[:, slot, :].unsqueeze(-1) + for local_slot in range(local_experts): + routing_weight = routing_weights_by_expert[:, local_slot, :].unsqueeze(-1) T2Ei = routing_weight.squeeze(-1) > 0 delta = _ctx_scatter_gather_gptoss_expert_blocked( x=x, T2Ei=T2Ei, - W_g=W_g[:, slot], - W_u=W_u[:, slot], - W_d=W_d[:, slot], - b_g=b_g[:, slot], - b_u=b_u[:, slot], - b_d=b_d[:, slot], + W_g=W_g[:, local_slot], + W_u=W_u[:, local_slot], + W_d=W_d[:, local_slot], + b_g=b_g[:, local_slot], + b_u=b_u[:, local_slot], + b_d=b_d[:, local_slot], limit=self.experts.limit, alpha=self.experts.alpha, T=T, ) expert_out_partial = expert_out_partial + (delta * routing_weight) + return expert_out_partial.sum(dim=0) def forward(self, hidden: torch.Tensor): diff --git a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py index 655de4ef5..3bc933909 100644 --- a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py +++ 
b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py
@@ -14,14 +14,15 @@
 from QEfficient import QEFFAutoModelForCausalLM
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 
-model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"  # weights are not required to convert to fp32
+# model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507"  # weights are not required to convert to fp32
+model_id = "yujiepan/qwen3-moe-tiny-random"
 prompt = """
 Explain quantum computing in simple terms.
 """
 config = AutoConfig.from_pretrained(model_id)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-PREFILL_SEQ_LEN = 128
-CTX_LEN = 128 * 3
+PREFILL_SEQ_LEN = 256
+CTX_LEN = PREFILL_SEQ_LEN * 3
 
 qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
 decode_qpc_path = qeff_model.compile(
@@ -48,7 +49,7 @@
     num_cores=16,
     mxfp6_matmul=True,
     mxint8_kv_cache=True,
-    num_devices=2,
+    num_devices=1,
     split_retained_state_io=True,
     mos=1,
     aic_enable_depth_first=True,

From 28048519ae21c8fa106af11562a82322acda91aa Mon Sep 17 00:00:00 2001
From: vbaddi
Date: Fri, 24 Apr 2026 19:27:10 +0530
Subject: [PATCH 5/7] nit(0424): ctx batch idx cast to int32

Signed-off-by: vbaddi
---
 QEfficient/customop/ctx_scatter_gather.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index bc8775707..4f46791af 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates
 
     # Create indices
     batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape)
+
+    # keep index tensor dtypes aligned for backends that require an exact dtype match
+    batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype)
     ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape)
     indices = ops.Concat(batch_idx, ctx_idx, axis=2)
 

From 6b049bcd883c982a8638838194a41155fde3e716 Mon Sep 17 00:00:00 2001
From: vbaddi
Date: Thu, 30 Apr 2026 07:11:10 +0530
Subject: [PATCH 6/7] nit(0429): qwen3_moe, gpt_oss: port cumsum scatter-gather-update MoE prefill

Signed-off-by: vbaddi
---
 QEfficient/customop/__init__.py            |   6 +
 QEfficient/customop/ctx_scatter_gather.py  |  93 +++++++++++++
 .../models/gpt_oss/modeling_gpt_oss.py     | 122 +++++++++++++-----
 .../models/qwen3_moe/modeling_qwen3_moe.py | 112 ++++++++++++----
 4 files changed, 276 insertions(+), 57 deletions(-)

diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py
index 35830aa91..4830e660c 100644
--- a/QEfficient/customop/__init__.py
+++ b/QEfficient/customop/__init__.py
@@ -8,9 +8,12 @@
 from QEfficient.customop.ctx_scatter_gather import (
     CtxGatherFunc,
     CtxGatherFunc3D,
+    CtxGatherFunc3DGeneralized,
     CtxGatherFuncBlockedKV,
     CtxScatterFunc,
     CtxScatterFunc3D,
+    CtxScatterFunc3DGeneralized,
+    CtxScatterFunc3DInt,
 )
 from QEfficient.customop.ctx_scatter_gather_cb import (
     CtxGatherFuncBlockedKVCB,
@@ -26,7 +29,10 @@
     "CtxGatherFuncBlockedKV",
     "CtxScatterFunc",
     "CtxGatherFunc3D",
+    "CtxGatherFunc3DGeneralized",
     "CtxScatterFunc3D",
+    "CtxScatterFunc3DGeneralized",
+    "CtxScatterFunc3DInt",
     "CustomRMSNormAIC",
     "GemmaCustomRMSNormAIC",
     "CtxGatherFuncCB",
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index 4f46791af..19f60886d 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -96,6 +96,74 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: 
torch.Value, updat return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) +class CtxScatterFunc3DGeneralized(torch.autograd.Function): + """Scatter variant that preserves ``data`` at invalid (INT32_MAX) positions. + + Unlike :class:`CtxScatterFunc3D`, which writes updates for invalid rows to + ``data.shape[1]-1`` (potentially clobbering valid content), this version + masks out invalid rows before scattering so ``data`` is left untouched where + ``position_ids == INT32_MAX``. + """ + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data) + + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxScatter3DInt( + data: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.INT32 +) -> onnxscript.INT32: + # Find dims + batch_size = ops.Gather(ops.Shape(data), [0]) + seq_len = ops.Gather(ops.Shape(position_ids), [1]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(batch_size, seq_len, one, axis=0) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape) + batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype) + ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape) + indices = ops.Concat(batch_idx, ctx_idx, axis=2) + + return ops.ScatterND(data, indices, updates) + + +class CtxScatterFunc3DInt(torch.autograd.Function): + """Int32-typed scatter used to build a packed->original index table.""" + + @staticmethod + def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): + data = data.clone() + valid = position_ids != torch.iinfo(torch.int32).max + batch_idx = torch.arange(data.shape[0], device=data.device).view(-1, 1).expand_as(position_ids) + data[batch_idx[valid], position_ids[valid].long()] = updates[valid] + return data + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxScatter3DInt, data, position_ids, updates).setTypeAs(data) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0])) @@ -119,6 +187,31 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor return g.onnxscript_op(CtxGather3D, data, ctx_indices).setTypeAs(data) +class CtxGatherFunc3DGeneralized(torch.autograd.Function): + """Gather variant that tolerates INT32_MAX indices (invalid rows read from 0). 
+ + Semantically equivalent to :class:`CtxGatherFunc3D` on the PyTorch side but + exposed as a separate autograd op so callers using the packed/cumsum scatter + pipeline can be easily recognized and so the ONNX symbolic omits + ``setTypeAs`` (needed when the caller already has a matching dtype on + ``data`` and wants the op signature to flow through without dtype pinning). + """ + + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = torch.arange(data.shape[0]).view(-1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) + return data[batch_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGather3D, data, ctx_indices) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGather( data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 1248e20ba..5e0270b7b 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -36,7 +36,11 @@ generic_blocked_attention_interface, past_key_value_update, ) -from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -52,9 +56,36 @@ def __qeff_init__(self): EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) +EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table for an NSP-sliced expert mask. -def _ctx_scatter_gather_gptoss_expert_blocked( + Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to + an expert, produces an index tensor where ``matched_idx[b, j]`` is the + original token position in ``x`` that lands at packed position ``j`` for + NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). + """ + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this + # can be switched back to ``torch.full_like(token_idx, int32_max)``. 
+ matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_gptoss_expert_blocked( x: torch.Tensor, T2Ei: torch.Tensor, W_g: torch.Tensor, @@ -63,37 +94,64 @@ def _ctx_scatter_gather_gptoss_expert_blocked( b_g: torch.Tensor, b_u: torch.Tensor, b_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, limit: float, alpha: float, T: int, + packed_chunk_size: int, ) -> torch.Tensor: - batch_size, hidden_size = T2Ei.shape[0], x.shape[1] - scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) - invalid_mask = ~T2Ei - int32_max = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) - scatter_safe_idx = torch.where(invalid_mask, int32_max, scatter_idx) - - x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) - x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + """Cumsum-scatter-gather-update expert helper for GPT-OSS NSP-blocked dispatch. + + Same algorithm as the Qwen3-MOE version but with GPT-OSS biases and GLU + activation (clamped gate/up, ``(up + 1) * gate * sigmoid(gate * alpha)``). + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + b_g, b_u : [num_nsp, I] + b_d : [num_nsp, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) + matched_idx = _build_matched_idx_from_cumsum(T2Ei) valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) - row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) - valid_output_rows = row_range < valid_rows - x_prime = torch.where(valid_output_rows.unsqueeze(-1), x_prime, torch.zeros_like(x_prime)) + row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) + + for packed_start in range(0, seq_len, packed_chunk_size): + packed_stop = packed_start + packed_chunk_size + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] - gate = (x_prime @ W_g) + b_g.unsqueeze(1) - up = (x_prime @ W_u) + b_u.unsqueeze(1) - gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * torch.sigmoid(gate * alpha) - intermediate = (up + 1) * glu - down_prime = (intermediate @ W_d) + b_d.unsqueeze(1) - down_prime = torch.where(valid_output_rows.unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate = (x_chunk @ W_g) + b_g.unsqueeze(1) + up = (x_chunk @ W_u) + b_u.unsqueeze(1) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit) + up = up.clamp(min=-limit, max=limit) + glu = gate * torch.sigmoid(gate * alpha) + intermediate = (up + 1) * glu + down_chunk = (intermediate @ W_d) + b_d.unsqueeze(1) + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size) + updated_chunk = 
torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) - gather_idx = torch.where(invalid_mask, int32_max, scatter_idx) - delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) - delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) - return delta_out + return expert_out class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): @@ -116,11 +174,11 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous() b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous() - expert_out_partial = x.new_zeros((num_nsp, T, H)) + expert_out = x.new_zeros((num_nsp, T, H)) for local_slot in range(local_experts): - routing_weight = routing_weights_by_expert[:, local_slot, :].unsqueeze(-1) - T2Ei = routing_weight.squeeze(-1) > 0 - delta = _ctx_scatter_gather_gptoss_expert_blocked( + routing_weight = routing_weights_by_expert[:, local_slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_gptoss_expert_blocked( x=x, T2Ei=T2Ei, W_g=W_g[:, local_slot], @@ -129,13 +187,15 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor b_g=b_g[:, local_slot], b_u=b_u[:, local_slot], b_d=b_d[:, local_slot], + routing_weight=routing_weight, + expert_out=expert_out, limit=self.experts.limit, alpha=self.experts.alpha, T=T, + packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, ) - expert_out_partial = expert_out_partial + (delta * routing_weight) - return expert_out_partial.sum(dim=0) + return expert_out.sum(dim=0) def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index e233e0e83..939d8faa9 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -33,7 +33,11 @@ generic_blocked_attention_interface, past_key_value_update, ) -from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D, CtxScatterFunc3D +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc3DGeneralized, + CtxScatterFunc3DGeneralized, + CtxScatterFunc3DInt, +) from QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE @@ -103,39 +107,93 @@ def eager_attention_forward( EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16")) +EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256")) + +def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: + """Build packed->original token index table for an NSP-sliced expert mask. -def _ctx_scatter_gather_expert_blocked( + Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to + an expert, produces an index tensor where ``matched_idx[b, j]`` is the + original token position in ``x`` that lands at packed position ``j`` for + NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). 
+ """ + batch_size, seq_len = T2Ei.shape + int32_max = torch.iinfo(torch.int32).max + int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) + token_idx = torch.arange(seq_len, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(batch_size, -1) + valid_prefix = torch.cumsum(T2Ei.to(torch.int32), dim=1) + valid_dest = valid_prefix - 1 + scatter_pos = torch.where(T2Ei, valid_dest, int32_max_scalar) + # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this + # can be switched back to ``torch.full_like(token_idx, int32_max)``. + matched_idx = int32_max_scalar.expand_as(token_idx) + matched_idx = CtxScatterFunc3DInt.apply( + matched_idx.unsqueeze(-1), + scatter_pos, + token_idx.unsqueeze(-1), + ).squeeze(-1) + return matched_idx + + +def _cumsum_scatter_gather_update_expert_blocked( x: torch.Tensor, T2Ei: torch.Tensor, W_g: torch.Tensor, W_u: torch.Tensor, W_d: torch.Tensor, + routing_weight: torch.Tensor, + expert_out: torch.Tensor, act_fn, T: int, + packed_chunk_size: int, ) -> torch.Tensor: - """Packed-prefix expert helper for NSP-blocked dispatch.""" - batch_size, hidden_size = T2Ei.shape[0], x.shape[1] - scatter_idx = (torch.cumsum(T2Ei.long(), dim=1) - 1).to(torch.int32) - invalid_mask = ~T2Ei - INT32_MAX = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=x.device) - scatter_safe_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) + """Cumsum-scatter-gather-update expert helper for NSP-blocked dispatch. + + Accumulates one local expert's contribution in-place onto ``expert_out``. + Uses a packed/cumsum layout so the MLP runs only over active rows, then + scatters the weighted output back to original token positions. + + Shapes: + x : [T, H] + T2Ei : [num_nsp, T] (bool) + W_g, W_u : [num_nsp, H, I] + W_d : [num_nsp, I, H] + routing_weight : [num_nsp, T] + expert_out : [num_nsp, T, H] (accumulator, in-out) + """ + batch_size, seq_len = T2Ei.shape + packed_chunk_size = max(1, min(packed_chunk_size, seq_len)) - x_prime = torch.zeros(batch_size, T, hidden_size, dtype=x.dtype, device=x.device) - x_prime = CtxScatterFunc3D.apply(x_prime, scatter_safe_idx, x.unsqueeze(0).expand(batch_size, -1, -1)) + matched_idx = _build_matched_idx_from_cumsum(T2Ei) + valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) + row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0) + x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1) + rw_expanded = routing_weight.unsqueeze(-1) - gate_prime = x_prime @ W_g - up_prime = x_prime @ W_u - down_prime = (up_prime * act_fn(gate_prime)) @ W_d + for packed_start in range(0, seq_len, packed_chunk_size): + packed_stop = packed_start + packed_chunk_size + chunk_matched_idx = matched_idx[:, packed_start:packed_stop] - valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True) - row_range = torch.arange(T, device=x.device, dtype=torch.int32).unsqueeze(0) - down_prime = torch.where((row_range < valid_rows).unsqueeze(-1), down_prime, torch.zeros_like(down_prime)) + x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx) + + gate_prime = x_chunk @ W_g + up_prime = x_chunk @ W_u + down_chunk = (up_prime * act_fn(gate_prime)) @ W_d + + rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx) + down_chunk = down_chunk * rw_chunk + + expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx) + updated_chunk = expert_out_chunk + down_chunk + + chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, 
max=packed_chunk_size) + updated_chunk = torch.where( + (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk) + ) + expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk) - gather_idx = torch.where(invalid_mask, INT32_MAX, scatter_idx) - delta_out = CtxGatherFunc3D.apply(down_prime, gather_idx) - delta_out = torch.where(invalid_mask.unsqueeze(-1), torch.zeros_like(delta_out), delta_out) - return delta_out + return expert_out class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): @@ -164,21 +222,23 @@ def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor W_g = self.gate_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() W_u = self.up_proj_w.view(local_experts, num_nsp, H, -1).transpose(0, 1).contiguous() W_d = self.down_proj_w.view(local_experts, num_nsp, -1, H).transpose(0, 1).contiguous() - expert_out_partial = x.new_zeros((num_nsp, T, H)) + expert_out = x.new_zeros((num_nsp, T, H)) for slot in range(local_experts): - routing_weight = rw[:, slot, :].unsqueeze(-1) - T2Ei = routing_weight.squeeze(-1) > 0 - delta = _ctx_scatter_gather_expert_blocked( + routing_weight = rw[:, slot, :] + T2Ei = routing_weight > 0 + expert_out = _cumsum_scatter_gather_update_expert_blocked( x=x, T2Ei=T2Ei, W_g=W_g[:, slot], W_u=W_u[:, slot], W_d=W_d[:, slot], + routing_weight=routing_weight, + expert_out=expert_out, act_fn=self.experts[0].act_fn, T=T, + packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE, ) - expert_out_partial = expert_out_partial + (delta * routing_weight) - return expert_out_partial.sum(dim=0) + return expert_out.sum(dim=0) def orig_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape From 1ae7b239ed471740df783e37a77a7aeac3e0e86a Mon Sep 17 00:00:00 2001 From: vbaddi Date: Thu, 30 Apr 2026 07:36:31 +0530 Subject: [PATCH 7/7] nit(0429): update modeling files Signed-off-by: vbaddi --- .../transformers/models/gpt_oss/modeling_gpt_oss.py | 8 +------- .../transformers/models/qwen3_moe/modeling_qwen3_moe.py | 8 +------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5e0270b7b..53dd72193 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -60,13 +60,7 @@ def __qeff_init__(self): def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: - """Build packed->original token index table for an NSP-sliced expert mask. - - Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to - an expert, produces an index tensor where ``matched_idx[b, j]`` is the - original token position in ``x`` that lands at packed position ``j`` for - NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). 
- """ + """Build packed->original token index""" batch_size, seq_len = T2Ei.shape int32_max = torch.iinfo(torch.int32).max int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 939d8faa9..942ebdc73 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -111,13 +111,7 @@ def eager_attention_forward( def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor: - """Build packed->original token index table for an NSP-sliced expert mask. - - Given ``T2Ei`` of shape ``[num_nsp, T]`` marking which tokens are routed to - an expert, produces an index tensor where ``matched_idx[b, j]`` is the - original token position in ``x`` that lands at packed position ``j`` for - NSP lane ``b`` (or ``INT32_MAX`` when ``j`` is past the last valid row). - """ + """Build packed->original token index""" batch_size, seq_len = T2Ei.shape int32_max = torch.iinfo(torch.int32).max int32_max_scalar = torch.tensor(int32_max, dtype=torch.int32, device=T2Ei.device)