Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@
from QEfficient.utils.logging_utils import logger


def _should_export_embedding_output(module) -> bool:
    """Return True when ``module`` (or its ``model`` attribute) enables ``export_embedding``.

    Two objects are inspected in order: ``module`` itself and, if present,
    ``module.model``. For each one the flag is considered enabled when either

    * its ``qaic_config`` attribute is a ``dict`` whose ``"export_embedding"``
      entry is truthy, or
    * its ``config`` attribute has a truthy ``export_embedding`` attribute.

    Args:
        module: Any object; missing attributes are treated as "not set".

    Returns:
        ``True`` if any inspected location enables the flag, otherwise ``False``.
    """
    for holder in (module, getattr(module, "model", None)):
        if holder is None:
            # Either ``module`` is None or it has no ``model`` attribute.
            continue

        # Only a plain dict ``qaic_config`` is consulted; other types are ignored.
        qaic_config = getattr(holder, "qaic_config", None)
        if isinstance(qaic_config, dict) and qaic_config.get("export_embedding", False):
            return True

        # Fall back to an attribute on the holder's ``config`` object.
        config = getattr(holder, "config", None)
        if config is not None and getattr(config, "export_embedding", False):
            return True

    return False


def qeff_apply_interleaved_mrope(freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
Expand Down Expand Up @@ -742,6 +755,8 @@ def forward(
hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
logits = self.model.lm_head(hidden_states)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
if _should_export_embedding_output(self):
return logits, vision_embeds, deepstack_features, image_idx, hidden_states, outputs.past_key_values
return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values


Expand Down Expand Up @@ -839,6 +854,8 @@ def forward(
hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
if _should_export_embedding_output(self):
return logits, image_embeds, image_idx, hidden_states, outputs.past_key_values
return logits, image_embeds, image_idx, outputs.past_key_values

def get_dummy_inputs(
Expand Down Expand Up @@ -1162,11 +1179,15 @@ def get_output_names(self, kv_offload: bool = False):
lang_output_names.insert(1, "vision_embeds_RetainedState")
lang_output_names.insert(2, "image_idx_output")
lang_output_names.insert(2, "deepstack_features_RetainedState")
if _should_export_embedding_output(self):
lang_output_names.insert(4, "embedding_output")
output_names["vision"] = vision_output_names
output_names["lang"] = lang_output_names
else:
lang_output_names.insert(1, "pixel_values_RetainedState")
lang_output_names.insert(2, "image_idx_output")
if _should_export_embedding_output(self):
lang_output_names.insert(3, "embedding_output")
return lang_output_names
return output_names

Expand Down
Loading
Loading