From 3236a4d44bb9aa8b72c663c81b62f0534ec81f61 Mon Sep 17 00:00:00 2001
From: Amit Raj
Date: Sat, 18 Apr 2026 10:55:00 +0000
Subject: [PATCH] Enable support for Qwen3-VL reranker models (2B and 8B)

Signed-off-by: Amit Raj
---
 .../transformers/models/modeling_auto.py      |   7 +-
 .../models/qwen3_vl/modeling_qwen3_vl.py      |  15 +-
 QEfficient/utils/generate_inputs.py           |   6 +-
 QEfficient/utils/test_utils.py                |   2 +
 docs/source/validate.md                       |   7 +
 examples/image_text_to_text/README.md         |   6 +
 .../models/qwen3vl/reranker/README.md         |  52 ++
 .../qwen3vl/reranker/qwen3_vl_reranker.py     | 555 ++++++++++++++++++
 scripts/Jenkinsfile                           |  96 +--
 tests/configs/image_text_model_configs.json   |  36 +-
 .../image_text_to_text/test_reranker_mad.py   | 455 ++++++++++++++
 11 files changed, 1185 insertions(+), 52 deletions(-)
 create mode 100644 examples/image_text_to_text/models/qwen3vl/reranker/README.md
 create mode 100644 examples/image_text_to_text/models/qwen3vl/reranker/qwen3_vl_reranker.py
 create mode 100644 tests/transformers/models/image_text_to_text/test_reranker_mad.py

diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 5f0eaf2b78..7d882729b2 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1328,6 +1328,7 @@ def export(
                 kv_offload=True,
                 continuous_batching=self.continuous_batching,
                 comp_ctx_lengths=self.comp_ctx_lengths_decode,
+                prefill_seq_len=prefill_seq_len,
             )
             dynamic_axes = self.model.get_onnx_dynamic_axes(
                 kv_offload=True,
@@ -1335,7 +1336,11 @@ def export(
                 comp_ctx_lengths=self.comp_ctx_lengths_decode,
             )
         except TypeError:
-            inputs = self.model.get_dummy_inputs(kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode)
+            inputs = self.model.get_dummy_inputs(
+                kv_offload=True,
+                comp_ctx_lengths=self.comp_ctx_lengths_decode,
+                prefill_seq_len=prefill_seq_len,
+            )
             dynamic_axes = self.model.get_onnx_dynamic_axes(
                 kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
             )
diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py
index 6d6c6b42d6..d88cb567f8 100644
--- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py
+++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -848,8 +848,13 @@ def get_dummy_inputs(
         continuous_batching: bool = False,
         **kwargs,
     ):
+        prefill_seq_len = kwargs.get("prefill_seq_len")
+        if prefill_seq_len is None:
+            prefill_seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+        prefill_seq_len = int(prefill_seq_len)
+
         inputs_shapes = {}
-        inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+        inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, prefill_seq_len)
         # vision_size = 1024
         vision_size = 187
         inputs_shapes["vision_embeds"] = (
@@ -861,7 +866,7 @@ def get_dummy_inputs(
         inputs_shapes["position_ids"] = (
             3,
             constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
-            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+            prefill_seq_len,
         )
         inputs_shapes["pixel_values"] = (748, 1536)
         inputs_shapes["image_idx"] = (1, 1)
@@ -881,8 +886,8 @@ def get_dummy_inputs(
         lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             (
-                torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
-                .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+                torch.arange(prefill_seq_len, dtype=torch.int64)
+                .view(1, prefill_seq_len)
                .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
             )
             .unsqueeze(0)
@@ -898,7 +903,7 @@ def get_dummy_inputs(
         kv_cache_shape = get_padding_shape_from_config(
             config=self.model.config.text_config,
             batch_size=fbs if continuous_batching else bs,
-            seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+            seq_len=prefill_seq_len,
         )
 
         lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.text_config.num_hidden_layers)]
diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py
index bb24e1b84b..c5e42dd0b5 100644
--- a/QEfficient/utils/generate_inputs.py
+++ b/QEfficient/utils/generate_inputs.py
@@ -364,8 +364,9 @@ def update_vlm_ort_outputs(self, ort_outputs):
         Return:
             updated_outputs (Dict): Updated past_key_values, logits, pixel_values
         """
+        num_layers = self.n_layer[0] if isinstance(self.n_layer, (list, tuple)) else self.n_layer
         present_key_values = []
-        for i in range(self.n_layer[0]):
+        for i in range(num_layers):
             if "past_key." + str(i) + "_RetainedState" in ort_outputs:
                 present_key_values.append(ort_outputs["past_key." + str(i) + "_RetainedState"])
             if "past_value." + str(i) + "_RetainedState" in ort_outputs:
@@ -397,7 +398,8 @@ def update_vlm_ort_inputs(self, inputs, ort_outputs):
         updated_inputs = {}
         updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1)
         updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1
-        for i in range(self.n_layer[0]):
+        num_layers = self.n_layer[0] if isinstance(self.n_layer, (list, tuple)) else self.n_layer
+        for i in range(num_layers):
             updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2]
             updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1]
         if "pixel_values_RetainedState" in ort_outputs.keys():
diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py
index e6acc52f2d..b67534e5b4 100644
--- a/QEfficient/utils/test_utils.py
+++ b/QEfficient/utils/test_utils.py
@@ -472,6 +472,8 @@ class ModelConfig:
         "Qwen/Qwen2.5-VL-3B-Instruct",
         "Qwen/Qwen3-VL-30B-A3B-Instruct",
         "Qwen/Qwen3-VL-2B-Instruct",
+        "Qwen/Qwen3-VL-Reranker-2B",
+        "Qwen/Qwen3-VL-Reranker-8B",
     }
 
     EXTERNAL_MODELS = {
diff --git a/docs/source/validate.md b/docs/source/validate.md
index 6e639bb30b..9a7216fa20 100644
--- a/docs/source/validate.md
+++ b/docs/source/validate.md
@@ -84,6 +84,13 @@
 | **Qwen2_5_VLForConditionalGeneration** | Qwen2.5-VL | [Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) | ✔️ | ✔️ | ✕ | ✔️ |
 | **Mistral3ForConditionalGeneration** | Mistral3| [mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)| ✕ | ✔️ | ✕ | ✕ |
 
+### Vision-Language Reranker Models (Text + Image Scoring)
+**QEff Auto Class:** `QEFFAutoModelForImageTextToText`
+
+| Architecture | Model Family | Representative Models | Qeff Single Qpc | Qeff Dual Qpc | vllm Single Qpc | vllm Dual Qpc |
+|------------------------------------|--------------|----------------------------------------------------------------------------------------|------------|---------------------|-------------------|-----------------|
+| **Qwen3VLForConditionalGeneration** | Qwen3-VL Reranker | [Qwen/Qwen3-VL-Reranker-2B](https://huggingface.co/Qwen/Qwen3-VL-Reranker-2B)<br>[Qwen/Qwen3-VL-Reranker-8B](https://huggingface.co/Qwen/Qwen3-VL-Reranker-8B) | ✕ | ✔️ | ✕ | ✕ |
+
 
 **Dual QPC:**
 
diff --git a/examples/image_text_to_text/README.md b/examples/image_text_to_text/README.md
index a6f1608b48..b58db258bd 100644
--- a/examples/image_text_to_text/README.md
+++ b/examples/image_text_to_text/README.md
@@ -100,12 +100,18 @@ Some models have specialized examples demonstrating advanced features:
 |-------|----------|
 | **Llama-4** | [models/llama4/](models/llama4/) |
 | **Qwen** | [models/qwen_vl/](models/qwen_vl/) |
+| **Qwen3-VL Reranker** | [models/qwen3vl/reranker/](models/qwen3vl/reranker/) |
 | **Mistral** | [models/mistral_vision/](models/mistral_vision/) |
 | **Gemma** | [models/gemma_vision/](models/gemma_vision/) |
 | **Granite** | [models/granite_vision/](models/granite_vision/) |
 | **InternVL** | [models/internvl/](models/internvl/) |
 | **Molmo** | [models/molmo/](models/molmo/) |
 
+Example command for the Qwen3-VL reranker:
+```bash
+python models/qwen3vl/reranker/qwen3_vl_reranker.py
+```
+
 ## Documentation
 
 - **Full Guide**: [VLM Documentation](../../docs/source/quick_start.md#vision-language-models)
diff --git a/examples/image_text_to_text/models/qwen3vl/reranker/README.md b/examples/image_text_to_text/models/qwen3vl/reranker/README.md
new file mode 100644
index 0000000000..a3e715478d
--- /dev/null
+++ b/examples/image_text_to_text/models/qwen3vl/reranker/README.md
@@ -0,0 +1,52 @@
+# Qwen3-VL Reranker Inference
+
+This directory contains an AI100 example that runs the Qwen3-VL reranker models with QEfficient and prints per-document relevance scores.
+
+Supported models:
+- `Qwen/Qwen3-VL-Reranker-2B`
+- `Qwen/Qwen3-VL-Reranker-8B`
+
+## What this example does
+
+- Loads the Qwen3-VL reranker from Hugging Face (or a local snapshot path).
+- Uses QEff dual-QPC execution (vision encoder + language model).
+- Runs the same query against multiple text/image documents.
+- Prints one score per document in input order.
+
+## Required package
+
+- `qwen-vl-utils>=0.0.14`
+
+```bash
+pip install "qwen-vl-utils>=0.0.14"
+```
+
+## Script
+
+- `qwen3_vl_reranker.py`
+
+## Run
+
+```bash
+python examples/image_text_to_text/models/qwen3vl/reranker/qwen3_vl_reranker.py \
+  --model-name Qwen/Qwen3-VL-Reranker-2B
+```
+
+Or run with 8B:
+
+```bash
+python examples/image_text_to_text/models/qwen3vl/reranker/qwen3_vl_reranker.py \
+  --model-name Qwen/Qwen3-VL-Reranker-8B
+```
+
+With compile parameters (`--ctx-len` must be at least `--compile-prefill-seq-len`):
+
+```bash
+python examples/image_text_to_text/models/qwen3vl/reranker/qwen3_vl_reranker.py \
+  --model-name Qwen/Qwen3-VL-Reranker-2B \
+  --ctx-len 4096 \
+  --num-cores 16 \
+  --num-devices 1 \
+  --compile-prefill-seq-len 4096 \
+  --mxfp6-matmul
+```
diff --git a/examples/image_text_to_text/models/qwen3vl/reranker/qwen3_vl_reranker.py b/examples/image_text_to_text/models/qwen3vl/reranker/qwen3_vl_reranker.py
new file mode 100644
index 0000000000..2fdd225571
--- /dev/null
+++ b/examples/image_text_to_text/models/qwen3vl/reranker/qwen3_vl_reranker.py
@@ -0,0 +1,555 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +import os +from typing import Dict, List, Tuple + +import numpy as np +import torch +from huggingface_hub import snapshot_download +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText +from QEfficient.generation.cloud_infer import QAICInferenceSession + +DEFAULT_MODEL_NAME = "Qwen/Qwen3-VL-Reranker-2B" +DEFAULT_CTX_LEN = 2048 +DEFAULT_NUM_CORES = 16 +DEFAULT_NUM_DEVICES = 1 + +# Max token budget used by this example's manual truncation/padding flow. +MAX_LENGTH = 8192 +# Pixel constraints used by Qwen3-VL preprocessing. +IMAGE_BASE_FACTOR = 16 +IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2 +MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR +MAX_PIXELS = 1280 * IMAGE_FACTOR * IMAGE_FACTOR +FPS = 1.0 + + +class QEffQwen3VLReranker: + @staticmethod + def _resolve_model_source(model_name_or_path: str) -> str: + """Return a local model path when given an HF repo id. + + Why: + Some transformers versions can fail when resolving chat templates from + repo-id mode for this model. Using a local snapshot path avoids that path. + """ + if os.path.isdir(model_name_or_path): + return model_name_or_path + return snapshot_download(repo_id=model_name_or_path) + + def __init__( + self, + model_name_or_path: str = DEFAULT_MODEL_NAME, + ctx_len: int = DEFAULT_CTX_LEN, + num_cores: int = DEFAULT_NUM_CORES, + num_devices: int = DEFAULT_NUM_DEVICES, + mxfp6_matmul: bool = False, + compile_prefill_seq_len: int = None, + ): + """Initialize the AI100-only reranker wrapper. + + This loads: + - HF config/processor for prompt and multimodal preprocessing. + - QEFF dual-QPC model wrapper (vision encoder + language decoder). + - Token ids for "yes"/"no" used to compute reranker scores. + + Parameters + ---------- + model_name_or_path: + HF model id or local snapshot path. + """ + self.model_name_or_path = model_name_or_path + self.model_source = self._resolve_model_source(model_name_or_path) + self.ctx_len = ctx_len + self.num_cores = num_cores + self.num_devices = num_devices + self.mxfp6_matmul = mxfp6_matmul + self.compile_prefill_seq_len = compile_prefill_seq_len + self.max_length = MAX_LENGTH + self.fps = FPS + + # Use local snapshot for stable processor/chat-template loading. 
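+        # KV caching must stay enabled for the dual-QPC flow: the language QPC
+        # exports past key/value retained-state buffers, so use_cache is forced
+        # on below for both the top-level and nested text configs.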
+ config = AutoConfig.from_pretrained(self.model_source, trust_remote_code=True, padding=True) + if hasattr(config, "use_cache"): + config.use_cache = True + if hasattr(config, "text_config") and hasattr(config.text_config, "use_cache"): + config.text_config.use_cache = True + + self.processor = AutoProcessor.from_pretrained(self.model_source, trust_remote_code=True, padding=True) + self.model = QEFFAutoModelForImageTextToText.from_pretrained( + self.model_source, + kv_offload=True, + trust_remote_code=True, + config=config, + ) + + self.yes_token_id, self.no_token_id = self._get_yes_no_token_ids(self.processor.tokenizer) + self._compiled_qpc_paths = None + self._compiled_prefill_seq_len = 0 + self._compiled_height = None + self._compiled_width = None + + @staticmethod + def _get_yes_no_token_ids(tokenizer) -> Tuple[int, int]: + """Resolve tokenizer ids for the exact tokens 'yes' and 'no'.""" + vocab = tokenizer.get_vocab() + if "yes" not in vocab or "no" not in vocab: + raise ValueError("Could not resolve tokenizer ids for exact tokens 'yes' and 'no'.") + return vocab["yes"], vocab["no"] + + @staticmethod + def _score_from_logits(logits, yes_token_id: int, no_token_id: int) -> float: + """Convert model logits into a reranker relevance score. + + Score formula: + sigmoid(logit_yes - logit_no) + """ + # Convert runtime output to torch and use final-token logits. + logits_tensor = torch.from_numpy(logits) if isinstance(logits, np.ndarray) else logits.detach().cpu() + if logits_tensor.ndim == 3: + logits_tensor = logits_tensor[:, -1, :] + # Binary relevance score from yes/no logit gap. + score = torch.sigmoid(logits_tensor[:, yes_token_id] - logits_tensor[:, no_token_id]) + return float(score[0].item()) + + @staticmethod + def _truncate_tokens_optimized(tokens: List[int], max_length: int, special_tokens: List[int]) -> List[int]: + """Truncate while preserving all special tokens in sequence order.""" + if len(tokens) <= max_length: + return tokens + + # Preserve all special/control tokens and trim only non-special tokens. + special_tokens_set = set(special_tokens) + num_special = sum(1 for token in tokens if token in special_tokens_set) + num_non_special_to_keep = max_length - num_special + + final_tokens = [] + non_special_kept_count = 0 + for token in tokens: + if token in special_tokens_set: + final_tokens.append(token) + elif non_special_kept_count < num_non_special_to_keep: + final_tokens.append(token) + non_special_kept_count += 1 + return final_tokens + + def _format_mm_content(self, text, image, video, prefix: str) -> List[Dict]: + """Build one multimodal content block (prefix + optional image + optional text).""" + # Prefix helps the model distinguish query vs document sections. + content = [{"type": "text", "text": prefix}] + + if not text and not image and not video: + content.append({"type": "text", "text": "NULL"}) + return content + + if video: + raise ValueError("Video input is not supported in this AI100-only example.") + + if image: + # Convert local paths to file:// URIs for the processor. 
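+            # Illustrative mapping (hypothetical path): "/tmp/demo.jpeg" becomes
+            # "file:///tmp/demo.jpeg"; http(s)/oss URLs and already-loaded PIL
+            # images pass through unchanged.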
+            if isinstance(image, str):
+                image_content = image if image.startswith(("http", "oss")) else "file://" + image
+            else:
+                image_content = image
+            content.append(
+                {
+                    "type": "image",
+                    "image": image_content,
+                    "min_pixels": MIN_PIXELS,
+                    "max_pixels": MAX_PIXELS,
+                }
+            )
+
+        if text:
+            content.append({"type": "text", "text": text})
+
+        return content
+
+    def _format_mm_instruction(self, instruction: str, query: Dict, document: Dict) -> List[Dict]:
+        """Create the chat payload for one query-document pair."""
+        # Prompt shape follows the HF reranker reference format:
+        # "<Instruct>: ...", "<Query>: ...", "\n<Document>: ...".
+        contents = [{"type": "text", "text": "<Instruct>: " + instruction}]
+
+        contents.extend(
+            self._format_mm_content(
+                query.get("text"),
+                query.get("image"),
+                query.get("video"),
+                prefix="<Query>:",
+            )
+        )
+        contents.extend(
+            self._format_mm_content(
+                document.get("text"),
+                document.get("image"),
+                document.get("video"),
+                prefix="\n<Document>:",
+            )
+        )
+
+        return [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": (
+                            "Judge whether the Document meets the requirements based on the Query and the Instruct "
+                            'provided. Note that the answer can only be "yes" or "no".'
+                        ),
+                    }
+                ],
+            },
+            {"role": "user", "content": contents},
+        ]
+
+    def _tokenize_pair(self, pair: List[Dict]) -> Dict:
+        """Tokenize a query-document pair with the exact HF multimodal pipeline."""
+        # Processor expects list-of-conversations.
+        pairs = [pair]
+        text = self.processor.apply_chat_template(pairs, tokenize=False, add_generation_prompt=True)
+
+        # Build image/video tensors + metadata for processor inputs.
+        images, videos, video_kwargs = process_vision_info(
+            pairs,
+            image_patch_size=16,
+            return_video_kwargs=True,
+            return_video_metadata=True,
+        )
+
+        if videos is not None:
+            videos, video_metadatas = zip(*videos)
+            videos = list(videos)
+            video_metadatas = list(video_metadatas)
+        else:
+            video_metadatas = None
+
+        inputs = self.processor(
+            text=text,
+            images=images,
+            videos=videos,
+            video_metadata=video_metadatas,
+            truncation=False,
+            padding=False,
+            do_resize=False,
+            **video_kwargs,
+        )
+
+        # Apply custom truncation preserving trailing template control tokens.
+        for i, input_ids in enumerate(inputs["input_ids"]):
+            inputs["input_ids"][i] = (
+                self._truncate_tokens_optimized(
+                    input_ids[:-5],
+                    self.max_length,
+                    self.processor.tokenizer.all_special_ids,
+                )
+                + input_ids[-5:]
+            )
+
+        # Re-pad through tokenizer utilities so masks align with token ids.
+        padded = self.processor.tokenizer.pad(
+            {"input_ids": inputs["input_ids"]},
+            padding=True,
+            return_tensors="pt",
+            max_length=self.max_length,
+        )
+        for key in padded:
+            inputs[key] = padded[key]
+
+        if "pixel_values" in inputs:
+            # Keep pixels fp32 before explicit cast to fp16 during vision run.
+            inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
+
+        return inputs
+
+    def _prepare_inputs(self, tokenized_inputs: Dict, prefill_seq_len: int = None):
+        """Prepare model inputs for dual-QPC prefill execution."""
+        # True prompt length before compile-aligned padding.
+        runtime_prompt_len = int(tokenized_inputs["input_ids"].shape[1])
+        effective_prefill = runtime_prompt_len if prefill_seq_len is None else prefill_seq_len
+        if effective_prefill < runtime_prompt_len:
+            raise ValueError(
+                f"prefill_seq_len ({effective_prefill}) must be >= runtime prompt length ({runtime_prompt_len})."
+            )
+
+        # Let model helper compute position_ids and multimodal placement.
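+        # The helper pads position_ids out to prefill_seq_len; input_ids may
+        # still be at the true prompt length and is aligned to the compiled
+        # length later, in _run_ai100_prefill.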
+ prepared_inputs = self.model.model.prepare_inputs_for_generation( + inputs=tokenized_inputs, + prefill_seq_len=effective_prefill, + batch_size=1, + ) + + # Normalize image_grid_thw to the shape consumed by compiled path. + if "image_grid_thw" in prepared_inputs and prepared_inputs["image_grid_thw"].ndim == 2: + thw = prepared_inputs["image_grid_thw"][0] + t, h, w = int(thw[0].item()), int(thw[1].item()), int(thw[2].item()) + prepared_inputs["image_grid_thw"] = torch.zeros((1, t, h, w), dtype=thw.dtype) + + if "pixel_values" in prepared_inputs: + prepared_inputs["pixel_values"] = prepared_inputs["pixel_values"].to(torch.float32) + + return prepared_inputs, runtime_prompt_len + + def _ensure_compiled(self, prefill_seq_len: int, height: int, width: int): + """Compile QPCs if needed, otherwise reuse cached compiled artifacts.""" + # Reuse previously compiled artifacts whenever shapes are compatible. + if ( + self._compiled_qpc_paths is not None + and prefill_seq_len <= self._compiled_prefill_seq_len + and height == self._compiled_height + and width == self._compiled_width + ): + return + + reuse_vision_qpc = ( + self._compiled_qpc_paths is not None and height == self._compiled_height and width == self._compiled_width + ) + + # Compile one max prefill specialization and optionally skip vision recompile. + compiled_paths = self.model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=self.ctx_len, + img_size=max(height, width), + height=height, + width=width, + num_cores=self.num_cores, + num_devices=self.num_devices, + mxfp6_matmul=self.mxfp6_matmul, + # vision_embed_fp32=True, + skip_vision=reuse_vision_qpc, + ) + if reuse_vision_qpc: + compiled_paths["vision_qpc_path"] = self._compiled_qpc_paths["vision_qpc_path"] + + self._compiled_qpc_paths = compiled_paths + self._compiled_prefill_seq_len = prefill_seq_len + self._compiled_height = height + self._compiled_width = width + + @staticmethod + def _zero_vision_outputs(vision_outputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """Create zero-valued placeholders matching vision output buffers.""" + return {name: np.zeros_like(value) for name, value in vision_outputs.items()} + + def _run_ai100_vision(self, prepared_inputs) -> Dict[str, np.ndarray]: + """Run the compiled vision encoder QPC and return retained-state buffers.""" + if "pixel_values" not in prepared_inputs or "image_grid_thw" not in prepared_inputs: + raise ValueError("Missing pixel_values/image_grid_thw for vision execution.") + + # Vision session produces retained states consumed by language session. + vision_session = QAICInferenceSession(self._compiled_qpc_paths["vision_qpc_path"]) + vision_outputs = vision_session.run( + { + # Vision qpc expects fp16 pixels + int64 grid coordinates. + "pixel_values": prepared_inputs["pixel_values"].detach().cpu().numpy().astype(np.float16), + "image_grid_thw": prepared_inputs["image_grid_thw"].detach().cpu().numpy().astype(np.int64), + } + ) + vision_session.deactivate() + return vision_outputs + + def _run_ai100_prefill(self, prepared_inputs, vision_template: Dict[str, np.ndarray]) -> np.ndarray: + """Run one prefill pass on AI100 language QPC and return logits.""" + # Match runtime input to compiled prefill length. 
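+        # position_ids already carries the compiled prefill length (see
+        # _prepare_inputs), so input_ids is padded with a benign token id (1)
+        # or sliced until both tensors agree with that specialization.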
+ prefill_len = prepared_inputs["position_ids"].shape[-1] + input_ids = prepared_inputs["input_ids"] + if input_ids.shape[1] < prefill_len: + pad = torch.full( + (input_ids.shape[0], prefill_len - input_ids.shape[1]), + 1, + dtype=input_ids.dtype, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, pad], dim=1) + else: + input_ids = input_ids[:, :prefill_len] + + position_ids = prepared_inputs["position_ids"][..., :prefill_len] + + # For text-only docs, inject zeroed retained states with matching shapes. + if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: + vision_outputs = self._run_ai100_vision(prepared_inputs) + else: + vision_outputs = self._zero_vision_outputs(vision_template) + + # Skip past/retained buffers and run only required prefill inputs. + lang_session = QAICInferenceSession(self._compiled_qpc_paths["lang_qpc_path"]) + lang_session.skip_buffers( + [ + name + for name in lang_session.input_names + lang_session.output_names + if name.startswith("past_") or name.endswith("_RetainedState") + ] + ) + lang_session.set_buffers(vision_outputs) + outputs = lang_session.run( + { + # image_idx selects the vision buffer slot for this request. + "input_ids": input_ids.detach().cpu().numpy().astype(np.int64), + "position_ids": position_ids.detach().cpu().numpy().astype(np.int64), + "image_idx": np.zeros((1, 1), dtype=np.int64), + } + ) + lang_session.deactivate() + return outputs["logits"] + + def process(self, inputs: Dict) -> List[float]: + """Score all documents for one query on AI100. + + High-level flow: + 1) Build model-ready query-document pairs. + 2) Find max prompt/image shape across all docs. + 3) Compile once at max shape (single stable specialization). + 4) Run prefill per doc and convert logits -> score. + """ + # Unpack user payload. + instruction = inputs["instruction"] + query = inputs.get("query", {}) + documents = inputs.get("documents", []) + + # Collect per-document tokenized contexts first so we can compile once + # with the largest prompt/image shape required by this request. + prepared_contexts = [] + max_prompt_len = 0 + max_grid_h = 22 + max_grid_w = 34 + + # Build each pair in the exact chat-template format expected by the model. + for document in documents: + pair = self._format_mm_instruction(instruction, query, document) + tokenized = self._tokenize_pair(pair) + runtime_prompt_len = int(tokenized["input_ids"].shape[1]) + + # Track the max image grid (H, W) seen so compile dimensions can + # handle all documents in this batch. + if "image_grid_thw" in tokenized and tokenized["image_grid_thw"].numel() > 0: + grid = tokenized["image_grid_thw"] + max_grid_h = max(max_grid_h, int(grid[..., 1].max().item())) + max_grid_w = max(max_grid_w, int(grid[..., 2].max().item())) + + prepared_contexts.append( + { + "tokenized": tokenized, + "runtime_prompt_len": runtime_prompt_len, + } + ) + max_prompt_len = max(max_prompt_len, runtime_prompt_len) + + # Empty documents list => no scores. + if max_prompt_len == 0: + return [] + + # Convert max grid to compile-time pixel dimensions using model patch size. + patch_size = int(self.model.model.config.vision_config.patch_size) + compile_height = max_grid_h * patch_size + compile_width = max_grid_w * patch_size + + # Compile/reuse a single language specialization and prepare all requests + # to that same prefill length to avoid per-document recompiles. 
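+        # Illustrative example (hypothetical lengths): documents with prompt
+        # lengths [312, 980, 1421] compile one specialization at 1421, or at
+        # --compile-prefill-seq-len when it is provided (it must be >= 1421),
+        # and every document is then padded to that single length.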
+ target_prefill_seq_len = max_prompt_len + if self.compile_prefill_seq_len is not None: + if self.compile_prefill_seq_len < max_prompt_len: + raise ValueError( + f"--compile-prefill-seq-len ({self.compile_prefill_seq_len}) must be >= " + f"max runtime prompt length ({max_prompt_len})." + ) + target_prefill_seq_len = self.compile_prefill_seq_len + + self._ensure_compiled(target_prefill_seq_len, compile_height, compile_width) + + # Prepare all documents to the same prefill length used at compile time. + prepared_contexts_with_prefill = [] + vision_template = None + for ctx in prepared_contexts: + prepared_inputs, _ = self._prepare_inputs(ctx["tokenized"], prefill_seq_len=target_prefill_seq_len) + prepared_contexts_with_prefill.append({"prepared_inputs": prepared_inputs}) + + # Capture one real vision-output template so text-only docs can reuse + # zero-valued buffers with exact matching shapes. + if vision_template is None and "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: + vision_template = self._run_ai100_vision(prepared_inputs) + + # This example currently expects at least one image document to establish + # retained-state buffer shapes for mixed image/text batches. + if vision_template is None: + raise ValueError("At least one image document is required to initialize AI100 vision buffers.") + + # Run language prefill and compute scalar score per document. + scores = [] + for ctx in prepared_contexts_with_prefill: + logits = self._run_ai100_prefill( + ctx["prepared_inputs"], + vision_template=vision_template, + ) + # Reranker score = sigmoid(logit_yes - logit_no). + score = self._score_from_logits(logits, self.yes_token_id, self.no_token_id) + scores.append(score) + + return scores + + +def main(): + # Keep CLI simple: just allow model id/path override. + parser = argparse.ArgumentParser(description="Qwen3-VL reranker example.") + parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME) + parser.add_argument("--ctx-len", type=int, default=DEFAULT_CTX_LEN, help="Context length used at compile time.") + parser.add_argument("--num-cores", type=int, default=DEFAULT_NUM_CORES, help="Number of AI100 cores.") + parser.add_argument("--num-devices", type=int, default=DEFAULT_NUM_DEVICES, help="Number of AI100 devices.") + parser.add_argument( + "--mxfp6-matmul", + action="store_true", + help="Enable MXFP6 matmul during compile (default: disabled).", + ) + parser.add_argument( + "--compile-prefill-seq-len", + type=int, + default=None, + help=( + "Optional fixed prefill sequence length for compile/padding. " + "Must be >= max prompt length of the current request." + ), + ) + args = parser.parse_args() + + model = QEffQwen3VLReranker( + model_name_or_path=args.model_name, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=args.num_devices, + mxfp6_matmul=args.mxfp6_matmul, + compile_prefill_seq_len=args.compile_prefill_seq_len, + ) + + # Example input payload matching the HF reranker schema. + inputs = { + "instruction": "Retrieve images or text relevant to the user's query.", + "query": {"text": "A woman playing with her dog on a beach at sunset."}, + "documents": [ + { + "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, as the dog offers its paw in a heartwarming display of companionship and trust." 
+ }, + {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}, + { + "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, as the dog offers its paw in a heartwarming display of companionship and trust.", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + ], + "fps": 1.0, + } + + # Print one score per document in the same order as inputs["documents"]. + scores = model.process(inputs) + print(scores) + + +if __name__ == "__main__": + main() diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index c2ec8b2add..6c7fdd01a4 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -107,49 +107,63 @@ pipeline { ''' } } - } - - } - } - stage('QAIC FEATURE') { - when {expression { params.RUN_QAIC_FEATURE }} - steps { - timeout(time: params.TEST_PROFILE == 'full_layers_model' ? 0 : 120, unit: params.TEST_PROFILE == 'full_layers_model' ? 'HOURS' : 'MINUTES') { - sh ''' - sudo docker exec ${BUILD_TAG} bash -c " - cd /efficient-transformers && - . preflight_qeff/bin/activate && - mkdir -p $PWD/Non_qaic_feature && - export TOKENIZERS_PARALLELISM=false && - export QEFF_HOME=$PWD/Non_qaic_feature && - pytest tests -m '(on_qaic) and (feature) and (not qnn) and ${TEST_FILTER}' --ignore tests/transformers/sampler --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log2_feature.xml --durations=10 && - junitparser merge tests/tests_log2_feature.xml tests/tests_log.xml && - deactivate" - ''' - } - } + } + stage('QAIC Feature Tests') { + steps { + timeout(time: 80, unit: 'MINUTES') { + sh ''' + sudo docker exec ${BUILD_TAG} bash -c " + cd /efficient-transformers && + . preflight_qeff/bin/activate && + mkdir -p $PWD/Non_qaic_feature && + export TOKENIZERS_PARALLELISM=false && + export QEFF_HOME=$PWD/Non_qaic_feature && + pytest tests -m '(not cli) and (on_qaic) and (feature) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --ignore tests/transformers/sampler --junitxml=tests/tests_log2_feature.xml --durations=10 && + junitparser merge tests/tests_log2_feature.xml tests/tests_log.xml && + deactivate" + ''' + } + } + } + } + } + stage('QAIC MultiModal Tests') { + steps { + timeout(time: 120, unit: 'MINUTES') { + sh ''' + sudo docker exec ${BUILD_TAG} bash -c " + cd /efficient-transformers && + . preflight_qeff/bin/activate && + mkdir -p $PWD/Non_cli_qaic_multimodal && + export TOKENIZERS_PARALLELISM=false && + export QEFF_HOME=$PWD/Non_cli_qaic_multimodal && + pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models) and (not nightly)' --ignore tests/vllm --ignore tests/unit_test --ignore tests/transformers/models/image_text_to_text/test_reranker_mad.py --junitxml=tests/tests_log6.xml --durations=10 && + + junitparser merge tests/tests_log6.xml tests/tests_log.xml && + deactivate" + ''' + } + } } - stage('QAIC Multimodal') { - when {expression { params.RUN_QAIC_MM }} - steps { - timeout(time: params.TEST_PROFILE == 'full_layers_model' ? 0 : 180, unit: params.TEST_PROFILE == 'full_layers_model' ? 'HOURS' : 'MINUTES') { - sh ''' - sudo docker exec ${BUILD_TAG} bash -c " - cd /efficient-transformers && - . 
preflight_qeff/bin/activate && - mkdir -p $PWD/Non_cli_qaic_multimodal && - export TOKENIZERS_PARALLELISM=false && - export QEFF_HOME=$PWD/Non_cli_qaic_multimodal && - pytest tests -m '(multimodal) and (not qnn) and ${TEST_FILTER}' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log6.xml --durations=10 && - junitparser merge tests/tests_log6.xml tests/tests_log.xml && - deactivate" - ''' - } - } + stage('QAIC Reranker Tests') { + steps { + timeout(time: 60, unit: 'MINUTES') { + sh ''' + sudo docker exec ${BUILD_TAG} bash -c " + cd /efficient-transformers && + . preflight_qeff/bin/activate && + mkdir -p $PWD/Non_cli_qaic_reranker && + export TOKENIZERS_PARALLELISM=false && + export QEFF_HOME=$PWD/Non_cli_qaic_reranker && + export QEFF_RERANKER_DOC_LIMIT=1 && + pytest -q tests/transformers/models/image_text_to_text/test_reranker_mad.py --maxfail=1 --junitxml=tests/tests_log_reranker.xml --durations=10 && + junitparser merge tests/tests_log_reranker.xml tests/tests_log.xml && + deactivate" + ''' + } + } } - - stage('Diffusion Models') { - when { expression { params.RUN_QAIC_DIFFUSION } } + stage('QAIC Diffusion Models Tests') { steps { timeout(time: 120, unit: 'MINUTES') { sh ''' diff --git a/tests/configs/image_text_model_configs.json b/tests/configs/image_text_model_configs.json index dad5112732..956f0fcd45 100644 --- a/tests/configs/image_text_model_configs.json +++ b/tests/configs/image_text_model_configs.json @@ -5,7 +5,7 @@ "model_type": "llava", "batch_size": 1, "prompt_len": 784, - "ctx_len": 1024, + "ctx_len": 2048, "img_size": 336, "img_url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg", "text_prompt": "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud", @@ -607,7 +607,37 @@ "text_prompt": "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud", "num_layers": 1, "additional_params": { + } + } + ], + "image_text_reranker_models": [ + { + "model_name": "Qwen/Qwen3-VL-Reranker-2B", + "model_type": "qwen3_vl", + "batch_size": 1, + "prompt_len": 128, + "ctx_len": 1024, + "img_size": 1540, + "img_url": "https://picsum.photos/id/237/536/354", + "instruction": "Retrieve candidates relevant to the query.", + "query_text": "A woman playing with her dog on a beach at sunset.", + "document_text": "A woman and her dog spend time together on a beach during sunset.", + "num_layers": 1, + "additional_params": {} + }, + { + "model_name": "Qwen/Qwen3-VL-Reranker-8B", + "model_type": "qwen3_vl", + "batch_size": 1, + "prompt_len": 128, + "ctx_len": 1024, + "img_size": 1540, + "img_url": "https://picsum.photos/id/237/536/354", + "instruction": "Retrieve candidates relevant to the query.", + "query_text": "A woman playing with her dog on a beach at sunset.", + "document_text": "A woman and her dog spend time together on a beach during sunset.", + "num_layers": 1, + "additional_params": {} } - } ] -} \ No newline at end of file +} diff --git a/tests/transformers/models/image_text_to_text/test_reranker_mad.py b/tests/transformers/models/image_text_to_text/test_reranker_mad.py new file mode 100644 index 0000000000..3a6497b520 --- /dev/null +++ b/tests/transformers/models/image_text_to_text/test_reranker_mad.py @@ -0,0 +1,455 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import json
+import os
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pytest
+import torch
+from huggingface_hub import snapshot_download
+from qwen_vl_utils import process_vision_info
+from transformers import AutoConfig, AutoProcessor
+
+from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText
+from QEfficient.utils.test_utils import load_vlm_model, set_num_layers_vlm
+
+CONFIG_PATH = "tests/configs/image_text_model_configs.json"
+
+PT_AI100_MAD_MAX = 5e-3
+MAX_LENGTH = 8192
+RERANKER_DOC_LIMIT = int(os.getenv("QEFF_RERANKER_DOC_LIMIT", "0"))
+
+IMAGE_BASE_FACTOR = 16
+IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
+MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
+MAX_PIXELS = 1280 * IMAGE_FACTOR * IMAGE_FACTOR
+
+EXAMPLE_INPUTS = {
+    "instruction": "Retrieve relevant content.",
+    "query": {"text": "dog on beach"},
+    "documents": [
+        {"image": "https://picsum.photos/id/237/536/354"},
+        {"text": "A dog running on the beach."},
+    ],
+}
+
+with open(CONFIG_PATH, "r") as f:
+    config_data = json.load(f)
+    reranker_models = config_data["image_text_reranker_models"]
+
+test_reranker_models = [model_config["model_name"] for model_config in reranker_models]
+reranker_model_config_dict = {model["model_name"]: model for model in reranker_models}
+
+
+def _resolve_model_source(model_name_or_path: str) -> str:
+    if os.path.isdir(model_name_or_path):
+        return model_name_or_path
+    return snapshot_download(repo_id=model_name_or_path)
+
+
+def _format_mm_content(text, image, video, prefix: str) -> List[Dict]:
+    content = [{"type": "text", "text": prefix}]
+
+    if not text and not image and not video:
+        content.append({"type": "text", "text": "NULL"})
+        return content
+
+    if video:
+        raise ValueError("Video input is not supported in this test.")
+
+    if image:
+        if isinstance(image, str):
+            image_content = image if image.startswith(("http", "oss")) else "file://" + image
+        else:
+            image_content = image
+        content.append(
+            {
+                "type": "image",
+                "image": image_content,
+                "min_pixels": MIN_PIXELS,
+                "max_pixels": MAX_PIXELS,
+            }
+        )
+
+    if text:
+        content.append({"type": "text", "text": text})
+
+    return content
+
+
+def _format_mm_instruction(instruction: str, query: Dict, document: Dict) -> List[Dict]:
+    contents = [{"type": "text", "text": "<Instruct>: " + instruction}]
+
+    contents.extend(
+        _format_mm_content(
+            query.get("text"),
+            query.get("image"),
+            query.get("video"),
+            prefix="<Query>:",
+        )
+    )
+    contents.extend(
+        _format_mm_content(
+            document.get("text"),
+            document.get("image"),
+            document.get("video"),
+            prefix="\n<Document>:",
+        )
+    )
+
+    return [
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": (
+                        "Judge whether the Document meets the requirements based on the Query and the Instruct "
+                        'provided. Note that the answer can only be "yes" or "no".'
+ ), + } + ], + }, + {"role": "user", "content": contents}, + ] + + +def _truncate_tokens_optimized(tokens: List[int], max_length: int, special_tokens: List[int]) -> List[int]: + if len(tokens) <= max_length: + return tokens + + special_tokens_set = set(special_tokens) + num_special = sum(1 for token in tokens if token in special_tokens_set) + num_non_special_to_keep = max_length - num_special + + final_tokens = [] + non_special_kept_count = 0 + for token in tokens: + if token in special_tokens_set: + final_tokens.append(token) + elif non_special_kept_count < num_non_special_to_keep: + final_tokens.append(token) + non_special_kept_count += 1 + return final_tokens + + +def _tokenize_pair(processor, pair: List[Dict]) -> Dict: + pairs = [pair] + text = processor.apply_chat_template(pairs, tokenize=False, add_generation_prompt=True) + + images, videos, video_kwargs = process_vision_info( + pairs, + image_patch_size=16, + return_video_kwargs=True, + return_video_metadata=True, + ) + + if videos is not None: + videos, video_metadatas = zip(*videos) + videos = list(videos) + video_metadatas = list(video_metadatas) + else: + video_metadatas = None + + inputs = processor( + text=text, + images=images, + videos=videos, + video_metadata=video_metadatas, + truncation=False, + padding=False, + do_resize=False, + **video_kwargs, + ) + + for i, input_ids in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = ( + _truncate_tokens_optimized( + input_ids[:-5], + MAX_LENGTH, + processor.tokenizer.all_special_ids, + ) + + input_ids[-5:] + ) + + padded = processor.tokenizer.pad( + {"input_ids": inputs["input_ids"]}, + padding=True, + return_tensors="pt", + max_length=MAX_LENGTH, + ) + for key in padded: + inputs[key] = padded[key] + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + return inputs + + +def _get_yes_no_token_ids(tokenizer) -> Tuple[int, int]: + vocab = tokenizer.get_vocab() + if "yes" not in vocab or "no" not in vocab: + raise ValueError("Could not resolve tokenizer ids for exact tokens 'yes' and 'no'.") + return vocab["yes"], vocab["no"] + + +def _score_from_logits(logits, yes_token_id: int, no_token_id: int) -> np.ndarray: + if isinstance(logits, np.ndarray): + logits_tensor = torch.from_numpy(logits) + else: + logits_tensor = logits.detach().cpu() + + if logits_tensor.ndim == 3: + logits_tensor = logits_tensor[:, -1, :] + elif logits_tensor.ndim != 2: + raise ValueError(f"Unsupported logits rank for score conversion: {logits_tensor.ndim}") + + score = torch.sigmoid(logits_tensor[:, yes_token_id] - logits_tensor[:, no_token_id]) + return score.detach().cpu().numpy().astype(np.float64) + + +def _score_from_last_hidden(last_hidden_state: torch.Tensor, score_linear: torch.nn.Linear) -> np.ndarray: + score = torch.sigmoid(score_linear(last_hidden_state[:, -1])).squeeze(-1) + return score.detach().cpu().numpy().astype(np.float64) + + +def _make_score_linear(model_hf, yes_token_id: int, no_token_id: int) -> torch.nn.Linear: + lm_head_weights = model_hf.lm_head.weight.data + weight_yes = lm_head_weights[yes_token_id] + weight_no = lm_head_weights[no_token_id] + + linear_layer = torch.nn.Linear(weight_yes.shape[0], 1, bias=False) + with torch.no_grad(): + linear_layer.weight[0] = weight_yes - weight_no + return linear_layer.eval() + + +def _mad_stats(reference: np.ndarray, candidate: np.ndarray) -> Tuple[float, float]: + diff = np.abs(reference - candidate) + return float(np.mean(diff)), float(np.max(diff)) + + +def 
_prepare_qeff_inputs(qeff_model, tokenized_inputs: Dict, prefill_seq_len: int = None): + runtime_prompt_len = int(tokenized_inputs["input_ids"].shape[1]) + effective_prefill_seq_len = runtime_prompt_len if prefill_seq_len is None else prefill_seq_len + if effective_prefill_seq_len < runtime_prompt_len: + raise ValueError( + f"prefill_seq_len ({effective_prefill_seq_len}) must be >= runtime prompt length ({runtime_prompt_len})." + ) + + prepared_inputs = qeff_model.model.prepare_inputs_for_generation( + inputs=tokenized_inputs, + prefill_seq_len=effective_prefill_seq_len, + batch_size=1, + ) + + if "image_grid_thw" in prepared_inputs and prepared_inputs["image_grid_thw"].ndim == 2: + thw = prepared_inputs["image_grid_thw"][0] + t, h, w = int(thw[0].item()), int(thw[1].item()), int(thw[2].item()) + prepared_inputs["image_grid_thw"] = torch.zeros((1, t, h, w), dtype=thw.dtype) + + if "pixel_values" in prepared_inputs: + prepared_inputs["pixel_values"] = prepared_inputs["pixel_values"].to(torch.float32) + + return prepared_inputs, runtime_prompt_len + + +def _zero_vision_outputs(vision_outputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + return {name: np.zeros_like(value) for name, value in vision_outputs.items()} + + +def _run_ai100_vision(vision_qpc_path: str, prepared_inputs) -> Dict[str, np.ndarray]: + vision_session = QAICInferenceSession(vision_qpc_path) + vision_inputs = { + "pixel_values": prepared_inputs["pixel_values"].detach().cpu().numpy().astype(np.float16), + "image_grid_thw": prepared_inputs["image_grid_thw"].detach().cpu().numpy().astype(np.int64), + } + vision_outputs = vision_session.run(vision_inputs) + vision_session.deactivate() + return vision_outputs + + +def _run_ai100_prefill(qpc_paths, prepared_inputs, vision_template): + if not isinstance(qpc_paths, dict): + raise ValueError("Expected qpc_paths to be a dict with vision/lang QPC keys.") + + vision_qpc_path = qpc_paths.get("vision_qpc_path") + lang_qpc_path = qpc_paths.get("lang_qpc_path") + if vision_qpc_path is None or lang_qpc_path is None: + raise ValueError("Missing vision_qpc_path/lang_qpc_path in compiled QPC outputs.") + + prefill_len = prepared_inputs["position_ids"].shape[-1] + input_ids = prepared_inputs["input_ids"] + if input_ids.shape[1] < prefill_len: + pad = torch.full( + (input_ids.shape[0], prefill_len - input_ids.shape[1]), + 1, + dtype=input_ids.dtype, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, pad], dim=1) + else: + input_ids = input_ids[:, :prefill_len] + position_ids = prepared_inputs["position_ids"][..., :prefill_len] + + if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: + vision_outputs = _run_ai100_vision(vision_qpc_path, prepared_inputs) + else: + vision_outputs = _zero_vision_outputs(vision_template) + + lang_session = QAICInferenceSession(lang_qpc_path) + lang_session.skip_buffers( + [ + name + for name in lang_session.input_names + lang_session.output_names + if name.startswith("past_") or name.endswith("_RetainedState") + ] + ) + lang_session.set_buffers(vision_outputs) + lang_inputs = { + "input_ids": input_ids.detach().cpu().numpy().astype(np.int64), + "position_ids": position_ids.detach().cpu().numpy().astype(np.int64), + "image_idx": np.zeros((1, 1), dtype=np.int64), + } + outputs = lang_session.run(lang_inputs) + lang_session.deactivate() + return outputs["logits"] + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.regular +@pytest.mark.parametrize("model_name", test_reranker_models) +def 
test_qwen3_vl_reranker_mad_parity(model_name): + torch.manual_seed(42) + model_cfg = reranker_model_config_dict[model_name] + model_source = _resolve_model_source(model_name) + + config = AutoConfig.from_pretrained(model_source, trust_remote_code=True, padding=True) + config = set_num_layers_vlm(config, n_layer=model_cfg["num_layers"]) + if hasattr(config, "use_cache"): + config.use_cache = True + if hasattr(config, "text_config") and hasattr(config.text_config, "use_cache"): + config.text_config.use_cache = True + + model_hf = load_vlm_model(config) + model_hf.eval() + + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_source, + kv_offload=True, + config=config, + ) + processor = AutoProcessor.from_pretrained(model_source, trust_remote_code=True, padding=True) + + yes_token_id, no_token_id = _get_yes_no_token_ids(processor.tokenizer) + score_linear = _make_score_linear(model_hf, yes_token_id, no_token_id).to(next(model_hf.parameters()).device) + score_linear = score_linear.to(dtype=next(model_hf.parameters()).dtype) + + doc_contexts = [] + max_prompt_len = 0 + max_grid_h = 22 + max_grid_w = 34 + + hf_scores_list = [] + + documents = EXAMPLE_INPUTS["documents"] + if RERANKER_DOC_LIMIT > 0: + documents = documents[:RERANKER_DOC_LIMIT] + + for document in documents: + pair = _format_mm_instruction( + instruction=EXAMPLE_INPUTS["instruction"], + query=EXAMPLE_INPUTS["query"], + document=document, + ) + tokenized = _tokenize_pair(processor, pair) + runtime_prompt_len = int(tokenized["input_ids"].shape[1]) + + hf_inputs = {} + for key, value in tokenized.items(): + hf_inputs[key] = value.to(next(model_hf.parameters()).device) if torch.is_tensor(value) else value + with torch.no_grad(): + hf_last_hidden = model_hf.model(**hf_inputs).last_hidden_state + hf_score = _score_from_last_hidden(hf_last_hidden, score_linear)[0] + hf_scores_list.append(float(hf_score)) + + if "image_grid_thw" in tokenized and tokenized["image_grid_thw"].numel() > 0: + grid = tokenized["image_grid_thw"] + max_grid_h = max(max_grid_h, int(grid[..., 1].max().item())) + max_grid_w = max(max_grid_w, int(grid[..., 2].max().item())) + + doc_contexts.append( + { + "tokenized": tokenized, + } + ) + max_prompt_len = max(max_prompt_len, runtime_prompt_len) + + patch_size = int(qeff_model.model.config.vision_config.patch_size) + compile_height = max_grid_h * patch_size + compile_width = max_grid_w * patch_size + + qpc_paths = qeff_model.compile( + img_size=max(compile_height, compile_width), + height=compile_height, + width=compile_width, + prefill_seq_len=max_prompt_len, + ctx_len=model_cfg["ctx_len"], + num_devices=1, + num_cores=16, + mxfp6_matmul=False, + ) + + ai100_scores_list = [] + + prepared_contexts = [] + vision_template_ai100 = None + for context in doc_contexts: + prepared_inputs, _ = _prepare_qeff_inputs( + qeff_model=qeff_model, + tokenized_inputs=context["tokenized"], + prefill_seq_len=max_prompt_len, + ) + prepared_contexts.append( + { + "prepared_inputs": prepared_inputs, + } + ) + if vision_template_ai100 is None and "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: + vision_template_ai100 = _run_ai100_vision(qpc_paths["vision_qpc_path"], prepared_inputs) + + if vision_template_ai100 is None: + raise ValueError("Expected at least one image document to initialize vision templates.") + + for context in prepared_contexts: + prepared_inputs_runtime = context["prepared_inputs"] + ai100_logits = _run_ai100_prefill( + qpc_paths=qpc_paths, + 
prepared_inputs=prepared_inputs_runtime, + vision_template=vision_template_ai100, + ) + ai100_score = _score_from_logits(ai100_logits, yes_token_id, no_token_id)[0] + ai100_scores_list.append(float(ai100_score)) + + hf_scores = np.array(hf_scores_list, dtype=np.float64) + ai100_scores = np.array(ai100_scores_list, dtype=np.float64) + + print(f"[SCORES] PyTorch(original): {hf_scores.tolist()}") + print(f"[SCORES] AI100: {ai100_scores.tolist()}") + + pt_ai100_mad_mean, pt_ai100_mad_max = _mad_stats(hf_scores, ai100_scores) + print(f"[MAD] PyTorch(original) vs AI100: mean={pt_ai100_mad_mean:.6e}, max={pt_ai100_mad_max:.6e}") + assert pt_ai100_mad_max <= PT_AI100_MAD_MAX, ( + f"PyTorch(original) vs AI100 MAD max {pt_ai100_mad_max:.6e} " + f"exceeds threshold {PT_AI100_MAD_MAX:.6e}. " + f"Check tokenizer ids, prompt formatting, runtime prompt length slicing, and compile dimensions." + )
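+
+# Usage note: CI runs this file standalone (see the QAIC Reranker Tests stage
+# in scripts/Jenkinsfile), e.g.:
+#   QEFF_RERANKER_DOC_LIMIT=1 pytest -q tests/transformers/models/image_text_to_text/test_reranker_mad.py --maxfail=1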