Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 24 additions & 44 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,16 +1051,16 @@ def export(
Path to the generated ONNX graph file for the language decoder.
"""
if prefill_only:
if prefill_seq_len > 1:
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
self.hash_params["prefill_only"] = True
self.prefill(enable=True, enable_chunking=enable_chunking)
else:
self.hash_params["prefill_only"] = False
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
assert prefill_seq_len > 1
if not enable_chunking and self.continuous_batching:
raise NotImplementedError(
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
)
self.hash_params["prefill_only"] = True
self.prefill(enable=True, enable_chunking=enable_chunking)
else:
self.hash_params["prefill_only"] = False
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))

return self._export(
inputs,
Expand Down Expand Up @@ -1242,24 +1242,6 @@ def onnx_path(self):
"""
return [self.vision_model.onnx_path, self.lang_model.onnx_path]

@property
def qpc_path(self):
"""
Get the QPC paths for the vision and language model components.

Returns
-------
Union[List[str], str, None]
A list containing both QPC paths if both are compiled, or just one if only one is,
or None if neither is compiled.
"""
if self.vision_model.qpc_path and self.lang_model.qpc_path:
return [self.vision_model.qpc_path, self.lang_model.qpc_path]
elif self.vision_model.qpc_path:
return self.vision_model.qpc_path
else:
return self.lang_model.qpc_path

def export(
self,
export_dir: Optional[str] = None,
Expand Down Expand Up @@ -1372,7 +1354,7 @@ def compile(
skip_vision: Optional[bool] = False,
skip_lang: Optional[bool] = False,
use_onnx_subfunctions: bool = False,
prefill_only=False,
prefill_only=None,
enable_chunking=False,
**compiler_options,
) -> str:
Expand Down Expand Up @@ -1492,11 +1474,7 @@ def compile(
if lang_onnx_path:
self.lang_model.onnx_path = lang_onnx_path

if (
(self.vision_model.onnx_path is None and vision_onnx_path is None)
or (self.lang_model.onnx_path is None and lang_onnx_path is None)
or prefill_only
):
if vision_onnx_path is None or lang_onnx_path is None:
self.export(
use_onnx_subfunctions=use_onnx_subfunctions,
skip_vision=skip_vision,
Expand All @@ -1510,8 +1488,9 @@ def compile(
compiler_options.pop("continuous_batching", None)
compiler_options.pop("kv_cache_batch_size", None)
compiler_options.pop("full_batch_size", None)
self.qpc_paths = {}
if not skip_vision:
self.vision_model._compile(
vision_qpc_path = self.vision_model._compile(
compile_dir=compile_dir,
compile_only=True,
specializations=specializations["vision"],
Expand All @@ -1524,6 +1503,7 @@ def compile(
use_onnx_subfunctions=use_onnx_subfunctions,
**compiler_options,
)
self.qpc_paths["vision_qpc_path"] = vision_qpc_path

# Custom NPI file options
if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options:
Expand All @@ -1548,16 +1528,17 @@ def compile(
if ("vision_embeds" in output_name or "deepstack_features" in output_name)
else kv_cache_dtype
)

if prefill_only:
if prefill_seq_len > 1:
specializations = specializations["lang"][:1] # prefill
else:
specializations = specializations["lang"][-1:] # decoder
specializations = specializations["lang"][:1]
qpc_key = "lang_prefill_qpc_path"
elif prefill_seq_len == 1:
specializations = specializations["lang"][-1:]
qpc_key = "lang_decode_qpc_path"
else:
specializations = specializations["lang"]
qpc_key = "lang_qpc_path"

self.lang_model._compile(
lang_qpc_path = self.lang_model._compile(
compile_dir=compile_dir,
compile_only=True,
retained_state=True,
Expand All @@ -1571,9 +1552,8 @@ def compile(
use_onnx_subfunctions=use_onnx_subfunctions,
**compiler_options,
)
if skip_vision and prefill_only: # for disagg serving
return self.lang_model.qpc_path
return self.qpc_path
self.qpc_paths.update({qpc_key: lang_qpc_path})
return self.qpc_paths

def generate(
self,
Expand Down
20 changes: 19 additions & 1 deletion QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# -----------------------------------------------------------------------------
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -562,6 +562,15 @@ def __init__(self, model):
self.model = model
self.model.vision_model = self.model.visual

def get_submodules_for_export(self) -> set[type[nn.Module]]:
    """
    Return the set of classes used as the repeated layer across the model
    for ONNX subfunction extraction.

    Returns
    -------
    set[type[nn.Module]]
        A set containing the *class object* (not an instance) of the repeated
        vision transformer block. Downstream code can use this to find/build
        subfunctions for repeated blocks.
    """
    # All vision blocks share one class, so the first block's class represents them all.
    return {self.model.visual.blocks[0].__class__}

def forward(self, pixel_values, image_grid_thw):
image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw)
bs = image_grid_thw.shape[0]
Expand All @@ -580,6 +589,15 @@ def __init__(self, model):
self.model = model
self.language_model = self.model.model

def get_submodules_for_export(self) -> set[type[nn.Module]]:
    """
    Return the set of classes used as the repeated layer across the model
    for ONNX subfunction extraction.

    Returns
    -------
    set[type[nn.Module]]
        A set containing the *class object* (not an instance) of the repeated
        text decoder layer. Downstream code can use this to find/build
        subfunctions for repeated blocks.
    """
    # The language model stacks this single decoder-layer class N times.
    return {QEffQwen3VLTextDecoderLayer}

def forward(
self,
input_ids,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# -----------------------------------------------------------------------------
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -629,7 +629,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
router_logits = self.gate(x) # [T, E]
prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype)
top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k]
top_w = top_w / top_w.sum(dim=-1, keepdim=True)
top_w = top_w / torch.einsum("bi->b", top_w)[:, None]
top_w = top_w.to(hidden_states.dtype)

# gate_up_proj: [E, H, 2I], down_proj: [E, I, H]
Expand Down Expand Up @@ -711,6 +711,15 @@ def __init__(self, model):
self.model = model
self.model.vision_model = self.model.visual

def get_submodules_for_export(self) -> set[type[nn.Module]]:
    """
    Return the set of classes used as the repeated layer across the model
    for ONNX subfunction extraction.

    Returns
    -------
    set[type[nn.Module]]
        A set containing the *class object* (not an instance) of the repeated
        vision transformer block. Downstream code can use this to find/build
        subfunctions for repeated blocks.
    """
    # All vision blocks share one class, so the first block's class represents them all.
    return {self.model.visual.blocks[0].__class__}

def forward(self, pixel_values, image_grid_thw):
image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw)
bs = image_grid_thw.shape[0]
Expand All @@ -727,7 +736,16 @@ class QEffQwen3VLDecoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.language_model = self.model.model
self.language_model = self.model.model.language_model

def get_submodules_for_export(self) -> set[type[nn.Module]]:
    """
    Return the set of classes used as the repeated layer across the model
    for ONNX subfunction extraction.

    Returns
    -------
    set[type[nn.Module]]
        A set containing the *class object* (not an instance) of the repeated
        MoE text decoder layer. Downstream code can use this to find/build
        subfunctions for repeated blocks.
    """
    # The language model stacks this single MoE decoder-layer class N times.
    return {QEffQwen3VLMoeTextDecoderLayer}

def forward(
self,
Expand Down Expand Up @@ -790,7 +808,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
router_logits = self.gate(x)
prob = F.softmax(router_logits, dim=-1, dtype=torch.float)
top_w, top_i = torch.topk(prob, self.top_k, dim=-1)
top_w = top_w / top_w.sum(dim=1, keepdim=True)
top_w = top_w / torch.einsum("bi->b", top_w)[:, None]
top_w = top_w.to(x.dtype)
idx = top_i.reshape(-1)
w_up = self.experts.gate_up_proj.index_select(0, idx)
Expand All @@ -805,9 +823,8 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
intermediate = up * self.experts.act_fn(gate)
experts_out = torch.bmm(intermediate, w_dn)
experts_out = experts_out.view(T, self.top_k, H) * top_w.unsqueeze(-1)
experts_out = experts_out.sum(dim=1).view(B, S, H)

return experts_out, router_logits
experts_out = torch.einsum("bnd->bd", experts_out)
return experts_out.view(B, S, H), router_logits


class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
Expand Down
Loading