2 changes: 1 addition & 1 deletion QEfficient/generation/embedding_handler.py
@@ -259,7 +259,7 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
inputs = self._qeff_model.model.prepare_inputs_for_generation(
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
)

if (
hasattr(self._qeff_model.model.config, "model_type")
and self._qeff_model.model.config.model_type == "qwen3_vl_moe"
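The `embedding_handler.py` change only touches formatting around the `qwen3_vl_moe` guard shown above. For context, that guard pattern boils down to a `getattr` check on the config; the helper below is a hypothetical distillation, not code from the PR:

```python
# Hypothetical helper distilling the model_type guard above; not part of the PR.
def is_model_type(qeff_model, model_type: str) -> bool:
    config = qeff_model.model.config
    return getattr(config, "model_type", None) == model_type
```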
4 changes: 2 additions & 2 deletions QEfficient/generation/vlm_generation.py
@@ -149,7 +149,7 @@ def __init__(
self.is_qwen2_5_vl = (
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl"
)
self.is_qwen3_vl_moe=(
self.is_qwen3_vl_moe = (
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe"
)
self.qeff_model = qeff_model
@@ -262,7 +262,7 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
if self.is_qwen2_5_vl:
_ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id)
elif self.is_qwen3_vl_moe:
_ = self.update_decode_inputs_qwen3_vl_moe(outputs,position_ids,generation_len,decode_batch_id)
_ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id)
else:
_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

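The prefill loop above now dispatches on per-model flags (`is_qwen2_5_vl`, `is_qwen3_vl_moe`). A sketch of the same choice factored into one helper; the method names mirror the diff, the helper itself is hypothetical:

```python
# Sketch only: the if/elif dispatch above factored into a single chooser.
def select_decode_updater(gen):
    """Pick the decode-input updater for a VisionLanguageGeneration-style object."""
    if getattr(gen, "is_qwen2_5_vl", False):
        return gen.update_decode_inputs_qwen2_5_vl
    if getattr(gen, "is_qwen3_vl_moe", False):
        return gen.update_decode_inputs_qwen3_vl_moe
    return gen.update_decode_input

# Usage: _ = select_decode_updater(self)(outputs, position_ids, generation_len, decode_batch_id)
```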
3 changes: 0 additions & 3 deletions QEfficient/transformers/cache_utils.py
@@ -157,7 +157,6 @@ def write_only(self, key_states, value_states, cache_kwargs):
self.keys = key_states
self.values = value_states
else:
# breakpoint()
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch the batch index value from the kwargs

@@ -192,7 +191,6 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
# breakpoint()
# Update the cache
# if not self.is_initialized:

@@ -371,7 +369,6 @@ def read_only(self, layer_idx, cache_kwargs):
Return:
A tuple containing the updated key and value states.
"""
# breakpoint()
return self.layers[layer_idx].read_only(cache_kwargs)

def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
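`write_only` above scatters new key/value states into the cache at `position_ids` (with an optional `batch_index`). A minimal standalone sketch of that position-indexed write, assuming a `[batch, heads, seq_len, head_dim]` cache layout; this mirrors the idea, not the exact implementation:

```python
import torch

# Minimal sketch of a position-indexed KV cache write.
def scatter_kv(cache: torch.Tensor, new_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: [batch, new_seq] -> broadcast to index the seq axis of the cache
    idx = position_ids[:, None, :, None].expand(-1, cache.shape[1], -1, cache.shape[3])
    return cache.scatter(2, idx, new_states)
```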
120 changes: 96 additions & 24 deletions QEfficient/transformers/models/modeling_auto.py
@@ -997,7 +997,36 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))

def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs):
def prefill(
self,
enable: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
retain_full_kv: Optional[bool] = False,
):
if enable:
if enable_chunking:
self.model, tf = PrefillOnlyChunkedTransform.apply(self.model)
else:
self.model, tf = PrefillOnlyTransform.apply(self.model)

else:
if retain_full_kv:
self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model)
else:
self.model, tf = RevertPrefillOnlyTransform.apply(self.model)

def export(
self,
inputs,
output_names,
dynamic_axes,
export_dir=None,
offload_pt_weights=True,
prefill_seq_len: Optional[int] = None,
prefill_only: bool = False,
enable_chunking: bool = False,
**kwargs,
):
"""
Exports the language decoder component to ONNX format.

@@ -1021,6 +1050,18 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
str
Path to the generated ONNX graph file for the language decoder.
"""
if prefill_only:
if prefill_seq_len > 1:
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
self.hash_params["prefill_only"] = True
self.prefill(enable=True, enable_chunking=enable_chunking)
else:
self.hash_params["prefill_only"] = False
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))

return self._export(
inputs,
output_names=output_names,
@@ -1223,6 +1264,11 @@ def export(
self,
export_dir: Optional[str] = None,
use_onnx_subfunctions: bool = False,
skip_vision: Optional[bool] = False,
skip_lang: Optional[bool] = False,
prefill_seq_len: Optional[int] = None,
prefill_only: bool = False,
enable_chunking: bool = False,
**kwargs,
) -> str:
"""
@@ -1276,26 +1322,33 @@
vocab_size=self.model.language_model.config.vocab_size,
qaic_config=self.lang_model.model.qaic_config,
)
if not skip_vision:
self.vision_model.export(
inputs["vision"],
output_names["vision"],
dynamic_axes["vision"],
export_dir=export_dir,
offload_pt_weights=False,
use_onnx_subfunctions=use_onnx_subfunctions,
)

self.vision_model.export(
inputs["vision"],
output_names["vision"],
dynamic_axes["vision"],
export_dir=export_dir,
offload_pt_weights=False,
use_onnx_subfunctions=use_onnx_subfunctions,
)

offload_pt_weights = kwargs.get("offload_pt_weights", True)
self.lang_model.export(
inputs["lang"],
output_names["lang"],
dynamic_axes["lang"],
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
use_onnx_subfunctions=use_onnx_subfunctions,
)
if prefill_only and prefill_seq_len > 1:
offload_pt_weights = False  # keep weights in memory for the decode ONNX export
else:
offload_pt_weights = kwargs.get("offload_pt_weights", True)

if not skip_lang:
self.lang_model.export(
inputs["lang"],
output_names["lang"],
dynamic_axes["lang"],
export_dir=export_dir,
offload_pt_weights=offload_pt_weights,
use_onnx_subfunctions=use_onnx_subfunctions,
prefill_only=prefill_only,
enable_chunking=enable_chunking,
prefill_seq_len=prefill_seq_len,
)
return self.onnx_path

def compile(
@@ -1319,6 +1372,8 @@ def compile(
skip_vision: Optional[bool] = False,
skip_lang: Optional[bool] = False,
use_onnx_subfunctions: bool = False,
prefill_only=False,
enable_chunking=False,
**compiler_options,
) -> str:
"""
@@ -1437,11 +1492,18 @@ def compile(
if lang_onnx_path:
self.lang_model.onnx_path = lang_onnx_path

if (self.vision_model.onnx_path is None and vision_onnx_path is None) or (
self.lang_model.onnx_path is None and lang_onnx_path is None
if (
(self.vision_model.onnx_path is None and vision_onnx_path is None)
or (self.lang_model.onnx_path is None and lang_onnx_path is None)
or prefill_only
):
self.export(
use_onnx_subfunctions=use_onnx_subfunctions,
skip_vision=skip_vision,
skip_lang=skip_lang,
prefill_only=prefill_only,
enable_chunking=enable_chunking,
prefill_seq_len=prefill_seq_len,
)

# TODO: this should be removed once continuous batching is supported for all the models.
@@ -1486,11 +1548,20 @@ def compile(
if ("vision_embeds" in output_name or "deepstack_features" in output_name)
else kv_cache_dtype
)

if prefill_only:
if prefill_seq_len > 1:
specializations = specializations["lang"][:1] # prefill
else:
specializations = specializations["lang"][-1:] # decoder
else:
specializations = specializations["lang"]

self.lang_model._compile(
onnx_path=self.lang_model.onnx_path,
compile_dir=compile_dir,
compile_only=True,
retained_state=True,
specializations=specializations["lang"],
specializations=specializations,
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
@@ -1500,6 +1571,8 @@ def compile(
use_onnx_subfunctions=use_onnx_subfunctions,
**compiler_options,
)
if skip_vision and prefill_only: # for disagg serving
return self.lang_model.qpc_path
return self.qpc_path

def generate(
@@ -1628,7 +1701,6 @@ def kv_offload_generate(
AssertionError
If `generation_len` is not greater than zero.
"""
# breakpoint()
if not self.lang_model.qpc_path:
raise TypeError("Please run compile API for language model first!")

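Taken together, the `modeling_auto.py` changes let a caller export and compile a prefill-only language QPC, optionally with chunking, and skip the vision graph for disaggregated serving. A usage sketch, assuming the top-level `QEFFAutoModelForImageTextToText` entry point; the model card and argument values are illustrative, not taken from the PR:

```python
# Usage sketch under assumptions: entry-point class and model card are
# illustrative; argument names follow the signatures in this diff.
from QEfficient import QEFFAutoModelForImageTextToText

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen3-VL-30B-A3B-Instruct",  # illustrative model card
    kv_offload=True,
)
qpc_path = model.compile(
    prefill_seq_len=128,    # > 1 selects the prefill specialization
    ctx_len=2048,
    prefill_only=True,      # applies the PrefillOnly* transforms before export
    enable_chunking=True,   # uses the chunked attention/MoE variants
    skip_vision=True,       # with prefill_only, returns lang_model.qpc_path (disagg serving)
)
```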
3 changes: 3 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -394,6 +394,7 @@
QEffQwen3Model,
)
from QEfficient.transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock,
QEffQwen3VLMoeForConditionalGeneration,
QEffQwen3VLMoeModel,
QEffQwen3VLMoeTextAttention,
@@ -646,6 +647,8 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform):
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
# Qwen3 VL Moe
QEffQwen3VLMoeTextSparseMoeBlock: QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock,
}


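The new mapping entry routes `QEffQwen3VLMoeTextSparseMoeBlock` to its chunked prefill variant. For context, `ModuleMappingTransform`-style transforms broadly work by rebinding module classes in place; a minimal sketch of that mechanism (an assumption about the internals, not a copy of them):

```python
import torch.nn as nn

# Minimal sketch of class-rebinding as used by ModuleMappingTransform-style
# transforms; the real implementation may differ.
def apply_mapping(model: nn.Module, mapping: dict) -> tuple[nn.Module, bool]:
    transformed = False
    for module in model.modules():
        replacement = mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement  # swap in the QEff forward()
            transformed = True
    return model, transformed
```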
QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -5,7 +5,7 @@
#
# -----------------------------------------------------------------------------
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
@@ -659,6 +659,15 @@ def __init__(self, model):
self.model = model
self.model.vision_model = self.model.visual

def get_submodules_for_export(self) -> set[Type[nn.Module]]:
"""
Return the set of classes used as the repeated layers across the model for subfunction extraction.
Notes:
This method should return *class objects* (not instances).
Downstream code can use this to find/build subfunctions for repeated blocks.
"""
return {self.model.visual.blocks[0].__class__}

def forward(self, pixel_values, image_grid_thw):
image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0]
bs = image_grid_thw.shape[0]
@@ -671,7 +680,16 @@ class QEffQwen3VLDecoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.language_model = self.model.model
self.language_model = self.model.model.language_model

def get_submodules_for_export(self) -> set[Type[nn.Module]]:
"""
Return the set of classes used as the repeated layers across the model for subfunction extraction.
Notes:
This method should return *class objects* (not instances).
Downstream code can use this to find/build subfunctions for repeated blocks.
"""
return {QEffQwen3VLDecoderWrapper}

def forward(
self,
@@ -714,7 +732,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
x = hidden_states.view(T, H)

router_logits = self.gate(x)
prob = F.softmax(router_logits, dim=-1, dtype=torch.float)
prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype)
top_w, top_i = torch.topk(prob, self.top_k, dim=-1)
top_w = top_w / top_w.sum(dim=1, keepdim=True)
top_w = top_w.to(x.dtype)
@@ -736,6 +754,40 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
return experts_out, router_logits


class QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock):
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
B, S, H = hidden_states.shape
T = B * S
x = hidden_states.view(T, H)
act = getattr(self.experts, "act_fn", F.silu)

router_logits = self.gate(x) # [T, E]
prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype)
top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k]
top_w = top_w / top_w.sum(dim=-1, keepdim=True)
top_w = top_w.to(hidden_states.dtype)

# gate_up_proj: [E, H, 2I], down_proj: [E, I, H]
W_up = self.experts.gate_up_proj
W_dn = self.experts.down_proj
E, H_w, twoI = W_up.shape
I2 = twoI // 2
routing_weights = torch.zeros_like(prob, dtype=hidden_states.dtype) # [T, E]
routing_weights.scatter_(1, top_i, top_w)
expert_out = x.new_zeros((T, H))
for e in range(E):
rw = routing_weights[:, e].unsqueeze(-1) # [T, 1]
# Split fused [H, 2I] -> [H, I] + [H, I]
W_gate_e = W_up[e, :, :I2]
W_up_e = W_up[e, :, I2:]
W_dn_e = W_dn[e, :, :]
gate = x @ W_gate_e
up = x @ W_up_e
down = (up * act(gate)) @ W_dn_e
expert_out.add_(down * rw)
return expert_out.view(B, S, H), router_logits


class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
def get_qeff_vision_encoder(self):
return QEffQwen3VLEncoderWrapper(self)
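The chunked MoE block densifies the top-k router weights into a full `[T, E]` matrix and accumulates each expert's contribution in a plain loop over experts, rather than gathering tokens per expert. A standalone numeric sketch of that routing with toy sizes; the weights here are random stand-ins:

```python
import torch
import torch.nn.functional as F

# Toy-sized sketch of the dense top-k routing used by the chunked MoE block.
T, H, E, K, I = 4, 8, 3, 2, 16   # tokens, hidden, experts, top-k, intermediate
x = torch.randn(T, H)
router_logits = torch.randn(T, E)

prob = F.softmax(router_logits, dim=-1)
top_w, top_i = torch.topk(prob, K, dim=-1)
top_w = top_w / top_w.sum(dim=-1, keepdim=True)             # renormalize over the chosen K

routing = torch.zeros_like(prob).scatter_(1, top_i, top_w)  # [T, E], zeros off the top-k
W_gate_up = torch.randn(E, H, 2 * I)                        # fused gate/up, as in the diff
W_down = torch.randn(E, I, H)

out = x.new_zeros(T, H)
for e in range(E):
    gate, up = (x @ W_gate_up[e]).chunk(2, dim=-1)          # split fused [H, 2I]
    out += ((up * F.silu(gate)) @ W_down[e]) * routing[:, e : e + 1]
```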
1 change: 1 addition & 0 deletions examples/qwen3_vl.py
@@ -6,6 +6,7 @@
# -----------------------------------------------------------------------------

import requests
import transformers
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoConfig, AutoProcessor, TextStreamer