Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,8 @@ def _compile(
onnx_path = Path(
onnx_path
if onnx_path
else self.onnx_path
if self.onnx_path
# else self.onnx_path
# if self.onnx_path
else self.get_onnx_path(
prefill_only,
enable_chunking,
Expand Down
6 changes: 6 additions & 0 deletions QEfficient/customop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from QEfficient.customop.ctx_scatter_gather import (
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFunc3DGeneralized,
CtxGatherFuncBlockedKV,
CtxScatterFunc,
CtxScatterFunc3D,
CtxScatterFunc3DGeneralized,
CtxScatterFunc3DInt,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherFuncBlockedKVCB,
Expand All @@ -26,7 +29,10 @@
"CtxGatherFuncBlockedKV",
"CtxScatterFunc",
"CtxGatherFunc3D",
"CtxGatherFunc3DGeneralized",
"CtxScatterFunc3D",
"CtxScatterFunc3DGeneralized",
"CtxScatterFunc3DInt",
"CustomRMSNormAIC",
"GemmaCustomRMSNormAIC",
"CtxGatherFuncCB",
Expand Down
100 changes: 99 additions & 1 deletion QEfficient/customop/ctx_scatter_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape)

# keep index tensor types aligned for backend that require exact dtype match
batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype)
ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape)
indices = ops.Concat(batch_idx, ctx_idx, axis=2)

Expand All @@ -78,8 +81,9 @@ def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates
class CtxScatterFunc3D(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
data = data.clone()
Comment thread
vbaddi marked this conversation as resolved.
batch_idx = torch.arange(data.shape[0]).view(-1, 1)
ctx_idx = position_ids
ctx_idx = torch.where(position_ids == torch.iinfo(torch.int32).max, data.shape[1] - 1, position_ids)
data[batch_idx, ctx_idx] = updates
return data

Expand All @@ -92,6 +96,74 @@ def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updat
return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data)


class CtxScatterFunc3DGeneralized(torch.autograd.Function):
    """Masked 3D scatter that leaves ``data`` untouched at invalid positions.

    Rows whose ``position_ids`` entry equals INT32_MAX are treated as padding
    and skipped entirely, unlike :class:`CtxScatterFunc3D`, which redirects
    such rows to ``data.shape[1] - 1`` and may clobber valid content there.
    """

    @staticmethod
    def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
        out = data.clone()
        sentinel = torch.iinfo(torch.int32).max
        keep = position_ids.ne(sentinel)
        rows = torch.arange(out.shape[0], device=out.device).unsqueeze(1).expand_as(position_ids)
        # Advanced indexing writes only the rows selected by `keep`.
        out[rows[keep], position_ids[keep].long()] = updates[keep]
        return out

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # Export/inference-only op: nothing is saved for backward.
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
        # ONNX export lowers to the shared CtxScatter3D custom-op script.
        return g.onnxscript_op(CtxScatter3D, data, position_ids, updates).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxScatter3DInt(
    data: onnxscript.INT32, position_ids: onnxscript.INT32, updates: onnxscript.INT32
) -> onnxscript.INT32:
    # Int32-typed twin of CtxScatter3D: ScatterND of `updates` into `data` at
    # (batch, position) coordinates built from `position_ids`.
    # Find dims
    batch_size = ops.Gather(ops.Shape(data), [0])
    seq_len = ops.Gather(ops.Shape(position_ids), [1])

    # Expanded shape to create indices
    zero = ops.Constant(value_ints=[0])
    one = ops.Constant(value_ints=[1])
    exp_shape = ops.Concat(batch_size, seq_len, one, axis=0)

    # Create indices
    batch_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, batch_size, one), [1, 2]), exp_shape)
    # Range over Shape outputs is INT64 per the ONNX spec; cast so both index
    # components are INT32 before Concat (backends may require exact dtype
    # agreement inside the ScatterND indices tensor).
    batch_idx = ops.Cast(batch_idx, to=onnxscript.INT32.dtype)
    ctx_idx = ops.Expand(ops.Unsqueeze(position_ids, [2]), exp_shape)
    indices = ops.Concat(batch_idx, ctx_idx, axis=2)

    return ops.ScatterND(data, indices, updates)


class CtxScatterFunc3DInt(torch.autograd.Function):
    """Int32-typed scatter used to build a packed->original index table.

    Entries of ``position_ids`` equal to INT32_MAX mark invalid rows; those
    rows are skipped so the matching slots of ``data`` keep their old values.
    """

    @staticmethod
    def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
        out = data.clone()
        sentinel = torch.iinfo(torch.int32).max
        keep = position_ids.ne(sentinel)
        rows = torch.arange(out.shape[0], device=out.device).unsqueeze(1).expand_as(position_ids)
        # Only the valid (non-sentinel) rows are written.
        out[rows[keep], position_ids[keep].long()] = updates[keep]
        return out

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # Export/inference-only op: nothing is saved for backward.
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
        # Export lowers to the int32-typed ScatterND custom-op script.
        return g.onnxscript_op(CtxScatter3DInt, data, position_ids, updates).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather3D(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[2], axes=[0]))
Expand All @@ -103,6 +175,7 @@ class CtxGatherFunc3D(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
batch_indices = torch.arange(data.shape[0]).view(-1, 1)
ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
return data[batch_indices, ctx_indices]

@staticmethod
Expand All @@ -114,6 +187,31 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor
return g.onnxscript_op(CtxGather3D, data, ctx_indices).setTypeAs(data)


class CtxGatherFunc3DGeneralized(torch.autograd.Function):
    """Gather variant that tolerates INT32_MAX indices (invalid rows read from 0).

    Semantically equivalent to :class:`CtxGatherFunc3D` on the PyTorch side but
    exposed as a separate autograd op so callers using the packed/cumsum scatter
    pipeline can be easily recognized and so the ONNX symbolic omits
    ``setTypeAs`` (needed when the caller already has a matching dtype on
    ``data`` and wants the op signature to flow through without dtype pinning).
    """

    @staticmethod
    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
        # Fix: allocate batch indices on data's device so CUDA inputs do not
        # hit a cross-device indexing error (matches the scatter variants,
        # which already pass device=data.device to torch.arange).
        batch_indices = torch.arange(data.shape[0], device=data.device).view(-1, 1)
        # INT32_MAX marks padding rows; redirect them to slot 0 (a harmless
        # read whose result is masked out downstream).
        ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
        return data[batch_indices, ctx_indices]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        # Export/inference-only op: nothing is saved for backward.
        pass

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
        # Deliberately no ``setTypeAs``: the exporter infers the output type.
        return g.onnxscript_op(CtxGather3D, data, ctx_indices)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(
data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# This is for supporting different modelling classes specially written for prefill-only model
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "kimi_k2", "kimi_k25"}
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss", "qwen3_moe", "kimi_k2", "kimi_k25"}

_PROXY_ONLY_ONNX_TRANSFORMS = (FP16ClipTransform, SplitTensorsTransform)

Expand Down
144 changes: 144 additions & 0 deletions QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
generic_blocked_attention_interface,
past_key_value_update,
)
from QEfficient.customop.ctx_scatter_gather import (
CtxGatherFunc3DGeneralized,
CtxScatterFunc3DGeneralized,
CtxScatterFunc3DInt,
)
from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
Expand All @@ -50,7 +55,142 @@ def __qeff_init__(self):
self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim))


# Number of NSP blocks the experts are sharded across for blocked dispatch;
# the model's num_experts must be divisible by this value (enforced in
# _forward_expert_blocked).  NOTE(review): presumably "NSP" refers to the
# Qualcomm AI accelerator cores — confirm against deployment docs.
EXPERT_BLOCKING_NUM_NSP = int(os.environ.get("EXPERT_BLOCKING_NUM_NSP", "16"))
# Packed-token chunk size processed per iteration of the cumsum
# scatter/gather expert loop; smaller values bound peak activation size.
EXPERT_BLOCKING_PACKED_CHUNK_SIZE = int(os.environ.get("EXPERT_BLOCKING_PACKED_CHUNK_SIZE", "256"))


def _build_matched_idx_from_cumsum(T2Ei: torch.Tensor) -> torch.Tensor:
    """Build packed->original token index.

    ``T2Ei`` is a boolean [batch, seq] token-to-expert match mask; the result
    holds, at each packed slot, the original token index of the matching
    token, with unused trailing slots left at INT32_MAX.
    """
    num_rows, total_tokens = T2Ei.shape
    sentinel = torch.tensor(torch.iinfo(torch.int32).max, dtype=torch.int32, device=T2Ei.device)
    original_idx = (
        torch.arange(total_tokens, dtype=torch.int32, device=T2Ei.device).unsqueeze(0).expand(num_rows, -1)
    )
    # Running count of matches gives each matched token its packed destination.
    packed_dest = torch.cumsum(T2Ei.to(torch.int32), dim=1) - 1
    scatter_pos = torch.where(T2Ei, packed_dest, sentinel)
    # Once the compiler fix for ConstantOfShape(INT32_MAX) is available, this
    # can be switched back to ``torch.full_like(original_idx, int32_max)``.
    table = sentinel.expand_as(original_idx)
    table = CtxScatterFunc3DInt.apply(
        table.unsqueeze(-1),
        scatter_pos,
        original_idx.unsqueeze(-1),
    )
    return table.squeeze(-1)


def _cumsum_scatter_gather_update_gptoss_expert_blocked(
    x: torch.Tensor,
    T2Ei: torch.Tensor,
    W_g: torch.Tensor,
    W_u: torch.Tensor,
    W_d: torch.Tensor,
    b_g: torch.Tensor,
    b_u: torch.Tensor,
    b_d: torch.Tensor,
    routing_weight: torch.Tensor,
    expert_out: torch.Tensor,
    limit: float,
    alpha: float,
    T: int,  # total token count; currently unused (seq_len comes from T2Ei) — kept for signature parity
    packed_chunk_size: int,
) -> torch.Tensor:
    """Cumsum-scatter-gather-update expert helper for GPT-OSS NSP-blocked dispatch.

    Same algorithm as the Qwen3-MOE version but with GPT-OSS biases and GLU
    activation (clamped gate/up, ``(up + 1) * gate * sigmoid(gate * alpha)``).

    Matched tokens are packed to the front of each row via the cumsum index
    table, processed in fixed-size chunks, and the results scattered back into
    the ``expert_out`` accumulator; INT32_MAX entries in the index table mark
    unused packed slots and are neutralized by the Generalized gather/scatter
    ops plus the ``row_range`` validity mask.

    Shapes:
        x : [T, H]
        T2Ei : [num_nsp, T] (bool)
        W_g, W_u : [num_nsp, H, I]
        W_d : [num_nsp, I, H]
        b_g, b_u : [num_nsp, I]
        b_d : [num_nsp, H]
        routing_weight : [num_nsp, T]
        expert_out : [num_nsp, T, H] (accumulator, in-out)
    """
    batch_size, seq_len = T2Ei.shape
    # Clamp the chunk size into [1, seq_len] so the loop below is well-formed.
    packed_chunk_size = max(1, min(packed_chunk_size, seq_len))

    # Packed slot -> original token index (INT32_MAX marks unused slots).
    matched_idx = _build_matched_idx_from_cumsum(T2Ei)
    # Per NSP row: how many packed slots actually hold a matched token.
    valid_rows = T2Ei.to(torch.int32).sum(dim=1, keepdim=True)
    row_range = torch.arange(packed_chunk_size, dtype=torch.int32, device=x.device).unsqueeze(0)
    x_expanded = x.unsqueeze(0).expand(batch_size, -1, -1)
    rw_expanded = routing_weight.unsqueeze(-1)

    for packed_start in range(0, seq_len, packed_chunk_size):
        packed_stop = packed_start + packed_chunk_size
        chunk_matched_idx = matched_idx[:, packed_start:packed_stop]

        # Gather this chunk's matched tokens (invalid slots read token 0).
        x_chunk = CtxGatherFunc3DGeneralized.apply(x_expanded, chunk_matched_idx)

        # GPT-OSS expert MLP: clamped gate/up projections + GLU activation.
        gate = (x_chunk @ W_g) + b_g.unsqueeze(1)
        up = (x_chunk @ W_u) + b_u.unsqueeze(1)
        # Gate is clamped below at fp16 min (not -limit) to match GPT-OSS.
        gate = gate.clamp(min=torch.finfo(torch.float16).min, max=limit)
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(gate * alpha)
        intermediate = (up + 1) * glu
        down_chunk = (intermediate @ W_d) + b_d.unsqueeze(1)

        # Scale by the per-token routing weight (gathered in packed order).
        rw_chunk = CtxGatherFunc3DGeneralized.apply(rw_expanded, chunk_matched_idx)
        down_chunk = down_chunk * rw_chunk

        # Read-modify-write: accumulate onto the existing expert output.
        expert_out_chunk = CtxGatherFunc3DGeneralized.apply(expert_out, chunk_matched_idx)
        updated_chunk = expert_out_chunk + down_chunk

        # Zero out rows past this chunk's valid count so the garbage read from
        # token 0 on invalid slots can never leak into the accumulator.
        chunk_valid_rows = torch.clamp(valid_rows - packed_start, min=0, max=packed_chunk_size)
        updated_chunk = torch.where(
            (row_range < chunk_valid_rows).unsqueeze(-1), updated_chunk, torch.zeros_like(updated_chunk)
        )
        # Generalized scatter skips INT32_MAX slots, preserving other tokens.
        expert_out = CtxScatterFunc3DGeneralized.apply(expert_out, chunk_matched_idx, updated_chunk)

    return expert_out


class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP):
    def _forward_expert_blocked(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
        """Dispatch tokens through experts sharded across NSP blocks.

        Args:
            x: [T, H] flattened token activations.
            routing_weights: [T, E] per-token weight for every expert
                (zero where the token is not routed to that expert).

        Returns:
            [T, H] combined expert output, summed over NSP blocks.

        Raises:
            ValueError: if ``num_experts`` is not divisible by
                ``EXPERT_BLOCKING_NUM_NSP``.
        """
        T, H = x.shape
        num_nsp = EXPERT_BLOCKING_NUM_NSP
        num_experts = self.experts.num_experts
        if num_experts % num_nsp != 0:
            raise ValueError(f"num_experts ({num_experts}) must be divisible by EXPERT_BLOCKING_NUM_NSP ({num_nsp})")

        local_experts = num_experts // num_nsp
        expert_dim = self.experts.expert_dim
        # Regroup [T, E] -> [num_nsp, local_experts, T]; with this view/transpose
        # pattern, NSP block n at local slot l handles expert l * num_nsp + n.
        routing_weights_by_expert = (
            routing_weights.transpose(0, 1).contiguous().view(local_experts, num_nsp, T).transpose(0, 1).contiguous()
        )
        # Regroup the fused expert weights/biases with the same expert mapping:
        # leading dim becomes num_nsp, second dim the local expert slot.
        W_g = self.experts.gate_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous()
        W_u = self.experts.up_proj.view(local_experts, num_nsp, H, expert_dim).transpose(0, 1).contiguous()
        W_d = self.experts.down_proj.view(local_experts, num_nsp, expert_dim, H).transpose(0, 1).contiguous()
        b_g = self.experts.gate_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous()
        b_u = self.experts.up_proj_bias.view(local_experts, num_nsp, expert_dim).transpose(0, 1).contiguous()
        b_d = self.experts.down_proj_bias.view(local_experts, num_nsp, H).transpose(0, 1).contiguous()

        # Accumulator: one [T, H] buffer per NSP block, summed at the end.
        expert_out = x.new_zeros((num_nsp, T, H))
        for local_slot in range(local_experts):
            routing_weight = routing_weights_by_expert[:, local_slot, :]
            # Token-to-expert match mask: a token belongs to this expert iff
            # its routing weight is strictly positive.
            T2Ei = routing_weight > 0
            expert_out = _cumsum_scatter_gather_update_gptoss_expert_blocked(
                x=x,
                T2Ei=T2Ei,
                W_g=W_g[:, local_slot],
                W_u=W_u[:, local_slot],
                W_d=W_d[:, local_slot],
                b_g=b_g[:, local_slot],
                b_u=b_u[:, local_slot],
                b_d=b_d[:, local_slot],
                routing_weight=routing_weight,
                expert_out=expert_out,
                limit=self.experts.limit,
                alpha=self.experts.alpha,
                T=T,
                packed_chunk_size=EXPERT_BLOCKING_PACKED_CHUNK_SIZE,
            )

        return expert_out.sum(dim=0)

def forward(self, hidden: torch.Tensor):
B, S, H = hidden.shape
T = B * S
Expand All @@ -69,6 +209,10 @@ def forward(self, hidden: torch.Tensor):
# Routing weights for each expert [T, E]
routing_weights = masked_logits

if self.experts.num_experts % EXPERT_BLOCKING_NUM_NSP == 0:
expert_out = self._forward_expert_blocked(x=hidden, routing_weights=routing_weights)
return expert_out.view(B, S, H), router_logits

# ────────────────── allocate the output tensor ─────
expert_out = hidden.new_zeros((T, H)) # accumulation buffer

Expand Down
Loading
Loading