From 3eb106ac57c6dca309a40d385d51bdbd876cd9ad Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Mon, 20 Apr 2026 09:53:27 +0000 Subject: [PATCH] Add Qwen3-VL embedding support and CPU vs AI100 parity test Signed-off-by: Amit Raj --- .../models/qwen3_vl/modeling_qwen3_vl.py | 21 + .../qwen3vl/embedding/qwen3_vl_embedding.py | 434 ++++++++++++++++++ tests/configs/image_text_model_configs.json | 15 +- .../test_audio_embedding_models.py | 3 - .../test_speech_seq2seq_models.py | 2 - .../causal_lm_models/check_causal_models.py | 1 - .../test_causal_lm_blocking_hqkv.py | 6 - .../causal_lm_models/test_causal_lm_models.py | 4 - .../causal_lm_models/test_causal_lm_pl1.py | 6 - .../test_causal_tlm_models.py | 6 - .../causal_lm_models/test_fp16_causal_lm.py | 3 - .../image_text_to_text/test_custom_dtype.py | 2 - .../image_text_to_text/test_embedding_mad.py | 130 ++++++ .../test_causal_lm_blocking_subfunction.py | 3 - .../subfunction/test_subfunction_vlm.py | 4 - 15 files changed, 599 insertions(+), 41 deletions(-) create mode 100644 examples/image_text_to_text/models/qwen3vl/embedding/qwen3_vl_embedding.py create mode 100644 tests/transformers/models/image_text_to_text/test_embedding_mad.py diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 6d6c6b42d..1ae70b3d8 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -44,6 +44,19 @@ from QEfficient.utils.logging_utils import logger +def _should_export_embedding_output(module) -> bool: + for holder in (module, getattr(module, "model", None)): + if holder is None: + continue + qaic_config = getattr(holder, "qaic_config", None) + if isinstance(qaic_config, dict) and qaic_config.get("export_embedding", False): + return True + config = getattr(holder, "config", None) + if config is not None and getattr(config, "export_embedding", False): + return True + return False + + def qeff_apply_interleaved_mrope(freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. 
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to @@ -742,6 +755,8 @@ def forward( hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] logits = self.model.lm_head(hidden_states) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + if _should_export_embedding_output(self): + return logits, vision_embeds, deepstack_features, image_idx, hidden_states, outputs.past_key_values return logits, vision_embeds, deepstack_features, image_idx, outputs.past_key_values @@ -839,6 +854,8 @@ def forward( hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] logits = self.lm_head(hidden_states) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + if _should_export_embedding_output(self): + return logits, image_embeds, image_idx, hidden_states, outputs.past_key_values return logits, image_embeds, image_idx, outputs.past_key_values def get_dummy_inputs( @@ -1162,11 +1179,15 @@ def get_output_names(self, kv_offload: bool = False): lang_output_names.insert(1, "vision_embeds_RetainedState") lang_output_names.insert(2, "image_idx_output") lang_output_names.insert(2, "deepstack_features_RetainedState") + if _should_export_embedding_output(self): + lang_output_names.insert(4, "embedding_output") output_names["vision"] = vision_output_names output_names["lang"] = lang_output_names else: lang_output_names.insert(1, "pixel_values_RetainedState") lang_output_names.insert(2, "image_idx_output") + if _should_export_embedding_output(self): + lang_output_names.insert(3, "embedding_output") return lang_output_names return output_names diff --git a/examples/image_text_to_text/models/qwen3vl/embedding/qwen3_vl_embedding.py b/examples/image_text_to_text/models/qwen3vl/embedding/qwen3_vl_embedding.py new file mode 100644 index 000000000..c14ffcdb4 --- /dev/null +++ b/examples/image_text_to_text/models/qwen3vl/embedding/qwen3_vl_embedding.py @@ -0,0 +1,434 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse +import os +import unicodedata +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from huggingface_hub import snapshot_download +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText +from QEfficient.generation.cloud_infer import QAICInferenceSession + +DEFAULT_MODEL_NAME = "Qwen/Qwen3-VL-Embedding-8B" +DEFAULT_CTX_LEN = 2048 +DEFAULT_NUM_CORES = 16 +DEFAULT_NUM_DEVICES = 1 +DEFAULT_INSTRUCTION = "Represent the user's input." 
+DEFAULT_NUM_HIDDEN_LAYERS = 36 +DEFAULT_VISION_DEPTH = 27 +DEFAULT_DEEPSTACK_INDEX = None + +MAX_LENGTH = 8192 +IMAGE_BASE_FACTOR = 16 +IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2 +MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR +MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR + + +class QEffQwen3VLEmbedder: + @staticmethod + def _resolve_model_source(model_name_or_path: str) -> str: + if os.path.isdir(model_name_or_path): + return model_name_or_path + return snapshot_download(repo_id=model_name_or_path) + + def __init__( + self, + model_name_or_path: str = DEFAULT_MODEL_NAME, + ctx_len: int = DEFAULT_CTX_LEN, + num_cores: int = DEFAULT_NUM_CORES, + num_devices: int = DEFAULT_NUM_DEVICES, + mxfp6_matmul: bool = False, + compile_prefill_seq_len: Optional[int] = None, + num_hidden_layers: int = DEFAULT_NUM_HIDDEN_LAYERS, + vision_depth: int = DEFAULT_VISION_DEPTH, + deepstack_index: Optional[int] = DEFAULT_DEEPSTACK_INDEX, + ): + self.model_name_or_path = model_name_or_path + self.model_source = self._resolve_model_source(model_name_or_path) + self.ctx_len = ctx_len + self.num_cores = num_cores + self.num_devices = num_devices + self.mxfp6_matmul = mxfp6_matmul + self.compile_prefill_seq_len = compile_prefill_seq_len + + config = AutoConfig.from_pretrained(self.model_source, trust_remote_code=True, padding=True) + if hasattr(config, "use_cache"): + config.use_cache = True + if hasattr(config, "text_config") and hasattr(config.text_config, "use_cache"): + config.text_config.use_cache = True + if hasattr(config, "text_config") and num_hidden_layers > 0: + config.text_config.num_hidden_layers = num_hidden_layers + if hasattr(config, "vision_config"): + if hasattr(config.vision_config, "depth") and vision_depth > 0: + config.vision_config.depth = vision_depth + if hasattr(config.vision_config, "deepstack_visual_indexes"): + max_valid_idx = max(0, config.vision_config.depth - 1) + if deepstack_index is None: + default_indexes = [int(idx) for idx in config.vision_config.deepstack_visual_indexes] + clamped_defaults = [idx for idx in default_indexes if 0 <= idx <= max_valid_idx] + config.vision_config.deepstack_visual_indexes = ( + clamped_defaults if clamped_defaults else [max_valid_idx] + ) + else: + config.vision_config.deepstack_visual_indexes = [min(max(0, int(deepstack_index)), max_valid_idx)] + + # Enable optional hidden-state export from the QEff Qwen3-VL decoder. + config.export_embedding = True + + self.processor = AutoProcessor.from_pretrained(self.model_source, trust_remote_code=True, padding=True) + self.model = QEFFAutoModelForImageTextToText.from_pretrained( + self.model_source, + kv_offload=True, + trust_remote_code=True, + config=config, + qaic_config={"export_embedding": True}, + ) + + self._compiled_qpc_paths = None + self._compiled_prefill_seq_len = None + self._compiled_height = None + self._compiled_width = None + + @staticmethod + def _normalize_instruction(instruction: str) -> str: + instruction = instruction.strip() + if instruction and not unicodedata.category(instruction[-1]).startswith("P"): + instruction += "." 
+ return instruction + + def format_model_input( + self, + text: Optional[str] = None, + image: Optional[Any] = None, + video: Optional[Any] = None, + instruction: Optional[str] = None, + ) -> List[Dict[str, Any]]: + resolved_instruction = self._normalize_instruction(instruction or DEFAULT_INSTRUCTION) + + content: List[Dict[str, Any]] = [] + conversation = [ + {"role": "system", "content": [{"type": "text", "text": resolved_instruction}]}, + {"role": "user", "content": content}, + ] + + if not text and not image and not video: + content.append({"type": "text", "text": "NULL"}) + return conversation + + if video: + raise ValueError("Video input is not supported in this example.") + + if image: + if isinstance(image, str): + image_content = image if image.startswith(("http://", "https://", "oss")) else "file://" + image + else: + image_content = image + content.append( + { + "type": "image", + "image": image_content, + "min_pixels": MIN_PIXELS, + "max_pixels": MAX_PIXELS, + } + ) + + if text: + content.append({"type": "text", "text": text}) + + return conversation + + def _tokenize_conversation(self, conversation: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + conversations = [conversation] + text = self.processor.apply_chat_template(conversations, tokenize=False, add_generation_prompt=True) + + images, videos, video_kwargs = process_vision_info( + conversations, + image_patch_size=16, + return_video_kwargs=True, + return_video_metadata=True, + ) + + if videos is not None: + videos, video_metadatas = zip(*videos) + videos = list(videos) + video_metadatas = list(video_metadatas) + else: + video_metadatas = None + + inputs = self.processor( + text=text, + images=images, + videos=videos, + video_metadata=video_metadatas, + truncation=True, + max_length=MAX_LENGTH, + padding=True, + do_resize=False, + return_tensors="pt", + **video_kwargs, + ) + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + return inputs + + @staticmethod + def _prepare_qeff_inputs(qeff_model, tokenized_inputs: Dict[str, torch.Tensor], prefill_seq_len: int): + runtime_prompt_len = int(tokenized_inputs["input_ids"].shape[1]) + if prefill_seq_len < runtime_prompt_len: + raise ValueError( + f"prefill_seq_len ({prefill_seq_len}) must be >= runtime prompt length ({runtime_prompt_len})." 
+ ) + + prepared_inputs = qeff_model.model.prepare_inputs_for_generation( + inputs=tokenized_inputs, + prefill_seq_len=prefill_seq_len, + batch_size=1, + ) + + if "image_grid_thw" in prepared_inputs and prepared_inputs["image_grid_thw"].ndim == 2: + thw = prepared_inputs["image_grid_thw"][0] + t, h, w = int(thw[0].item()), int(thw[1].item()), int(thw[2].item()) + prepared_inputs["image_grid_thw"] = torch.zeros((1, t, h, w), dtype=thw.dtype) + + if "pixel_values" in prepared_inputs: + prepared_inputs["pixel_values"] = prepared_inputs["pixel_values"].to(torch.float32) + + return prepared_inputs + + @staticmethod + def _zero_vision_outputs(vision_outputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + return {name: np.zeros_like(value) for name, value in vision_outputs.items()} + + @staticmethod + def _run_ai100_vision(vision_qpc_path: str, prepared_inputs: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]: + vision_session = QAICInferenceSession(vision_qpc_path) + vision_inputs = { + "pixel_values": prepared_inputs["pixel_values"].detach().cpu().numpy().astype(np.float16), + "image_grid_thw": prepared_inputs["image_grid_thw"].detach().cpu().numpy().astype(np.int64), + } + vision_outputs = vision_session.run(vision_inputs) + vision_session.deactivate() + return vision_outputs + + @staticmethod + def _run_ai100_prefill( + qpc_paths: Dict[str, str], + prepared_inputs: Dict[str, torch.Tensor], + vision_outputs: Dict[str, np.ndarray], + ) -> np.ndarray: + lang_qpc_path = qpc_paths.get("lang_qpc_path") + if lang_qpc_path is None: + raise ValueError("Missing lang_qpc_path in compiled QPC outputs.") + + prefill_len = prepared_inputs["position_ids"].shape[-1] + input_ids = prepared_inputs["input_ids"] + if input_ids.shape[1] < prefill_len: + pad = torch.full( + (input_ids.shape[0], prefill_len - input_ids.shape[1]), + 1, + dtype=input_ids.dtype, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, pad], dim=1) + else: + input_ids = input_ids[:, :prefill_len] + + position_ids = prepared_inputs["position_ids"][..., :prefill_len] + + lang_session = QAICInferenceSession(lang_qpc_path) + lang_session.skip_buffers( + [ + name + for name in lang_session.input_names + lang_session.output_names + if name.startswith("past_") or name.endswith("_RetainedState") + ] + ) + lang_session.set_buffers(vision_outputs) + lang_inputs = { + "input_ids": input_ids.detach().cpu().numpy().astype(np.int64), + "position_ids": position_ids.detach().cpu().numpy().astype(np.int64), + "image_idx": np.zeros((1, 1), dtype=np.int64), + } + + outputs = lang_session.run(lang_inputs) + lang_session.deactivate() + + if "embedding_output" not in outputs: + raise KeyError( + "Missing 'embedding_output' in AI100 decoder outputs. Ensure export_embedding is enabled in config/qaic_config." 
+ ) + embedding_output = outputs["embedding_output"] + if embedding_output.ndim > 2: + embedding_output = embedding_output.reshape(embedding_output.shape[0], -1) + return embedding_output + + def _compile_if_needed(self, tokenized_inputs_list: List[Dict[str, torch.Tensor]]) -> Tuple[Dict[str, str], int]: + max_prompt_len = 0 + max_grid_h = 22 + max_grid_w = 34 + + for tokenized in tokenized_inputs_list: + max_prompt_len = max(max_prompt_len, int(tokenized["input_ids"].shape[1])) + if "image_grid_thw" in tokenized and tokenized["image_grid_thw"].numel() > 0: + grid = tokenized["image_grid_thw"] + max_grid_h = max(max_grid_h, int(grid[..., 1].max().item())) + max_grid_w = max(max_grid_w, int(grid[..., 2].max().item())) + + effective_prefill = ( + max_prompt_len if self.compile_prefill_seq_len is None else int(self.compile_prefill_seq_len) + ) + if effective_prefill < max_prompt_len: + raise ValueError( + f"compile_prefill_seq_len ({effective_prefill}) must be >= max runtime prompt length ({max_prompt_len})." + ) + + patch_size = int(self.model.model.config.vision_config.patch_size) + compile_height = max_grid_h * patch_size + compile_width = max_grid_w * patch_size + + if ( + self._compiled_qpc_paths is not None + and self._compiled_prefill_seq_len == effective_prefill + and self._compiled_height == compile_height + and self._compiled_width == compile_width + ): + return self._compiled_qpc_paths, effective_prefill + + qpc_paths = self.model.compile( + img_size=max(compile_height, compile_width), + height=compile_height, + width=compile_width, + prefill_seq_len=effective_prefill, + ctx_len=self.ctx_len, + num_devices=self.num_devices, + num_cores=self.num_cores, + mxfp6_matmul=self.mxfp6_matmul, + ) + + self._compiled_qpc_paths = qpc_paths + self._compiled_prefill_seq_len = effective_prefill + self._compiled_height = compile_height + self._compiled_width = compile_width + return qpc_paths, effective_prefill + + def process(self, inputs: List[Dict[str, Any]], normalize: bool = True) -> torch.Tensor: + conversations = [ + self.format_model_input( + text=entry.get("text"), + image=entry.get("image"), + video=entry.get("video"), + instruction=entry.get("instruction"), + ) + for entry in inputs + ] + + tokenized_inputs_list = [self._tokenize_conversation(conversation) for conversation in conversations] + qpc_paths, prefill_seq_len = self._compile_if_needed(tokenized_inputs_list) + + prepared_inputs_list = [ + self._prepare_qeff_inputs(self.model, tokenized_inputs, prefill_seq_len=prefill_seq_len) + for tokenized_inputs in tokenized_inputs_list + ] + + vision_template = None + for prepared_inputs in prepared_inputs_list: + if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: + vision_template = self._run_ai100_vision(qpc_paths["vision_qpc_path"], prepared_inputs) + break + + if vision_template is None: + raise ValueError("At least one input with an image is required to initialize the vision path.") + + embedding_rows = [] + for prepared_inputs in prepared_inputs_list: + if "pixel_values" in prepared_inputs and "image_grid_thw" in prepared_inputs: + vision_outputs = self._run_ai100_vision(qpc_paths["vision_qpc_path"], prepared_inputs) + else: + vision_outputs = self._zero_vision_outputs(vision_template) + + embedding_output = self._run_ai100_prefill( + qpc_paths=qpc_paths, + prepared_inputs=prepared_inputs, + vision_outputs=vision_outputs, + ) + embedding_rows.append(torch.from_numpy(embedding_output).to(torch.float32)) + + embeddings = torch.cat(embedding_rows, dim=0) + if 
normalize: + embeddings = F.normalize(embeddings, p=2, dim=-1) + return embeddings + + +def parse_args(): + parser = argparse.ArgumentParser(description="Qwen3-VL-Embedding AI100 inference") + parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_NAME) + parser.add_argument("--ctx-len", type=int, default=DEFAULT_CTX_LEN) + parser.add_argument("--num-cores", type=int, default=DEFAULT_NUM_CORES) + parser.add_argument("--num-devices", type=int, default=DEFAULT_NUM_DEVICES) + parser.add_argument("--mxfp6-matmul", action="store_true") + parser.add_argument("--compile-prefill-seq-len", type=int, default=None) + parser.add_argument("--num-hidden-layers", type=int, default=DEFAULT_NUM_HIDDEN_LAYERS) + parser.add_argument("--vision-depth", type=int, default=DEFAULT_VISION_DEPTH) + parser.add_argument("--deepstack-index", type=int, default=DEFAULT_DEEPSTACK_INDEX) + return parser.parse_args() + + +def main(): + args = parse_args() + + queries = [ + {"text": "A woman playing with her dog on a beach at sunset."}, + {"text": "Pet owner training dog outdoors near water."}, + {"text": "Woman surfing on waves during a sunny day."}, + {"text": "City skyline view from a high-rise building at night."}, + ] + + documents = [ + { + "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, " + "as the dog offers its paw in a heartwarming display of companionship and trust." + }, + {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}, + { + "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset, " + "as the dog offers its paw in a heartwarming display of companionship and trust.", + "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + }, + ] + + embedder = QEffQwen3VLEmbedder( + model_name_or_path=args.model_name, + ctx_len=args.ctx_len, + num_cores=args.num_cores, + num_devices=args.num_devices, + mxfp6_matmul=args.mxfp6_matmul, + compile_prefill_seq_len=args.compile_prefill_seq_len, + num_hidden_layers=args.num_hidden_layers, + vision_depth=args.vision_depth, + deepstack_index=args.deepstack_index, + ) + + model_inputs = queries + documents + embeddings = embedder.process(model_inputs) + + q_count = len(queries) + similarity_scores = embeddings[:q_count] @ embeddings[q_count:].T + print(similarity_scores.tolist()) + + +if __name__ == "__main__": + main() diff --git a/tests/configs/image_text_model_configs.json b/tests/configs/image_text_model_configs.json index dad511273..fdfc05c45 100644 --- a/tests/configs/image_text_model_configs.json +++ b/tests/configs/image_text_model_configs.json @@ -609,5 +609,18 @@ "additional_params": { } } + ], + "image_text_embedding_models": [ + { + "model_name": "Qwen/Qwen3-VL-Embedding-8B", + "model_type": "qwen3_vl", + "batch_size": 1, + "ctx_len": 2048, + "num_layers": 1, + "vision_depth": 9, + "deepstack_index": 8, + "compile_prefill_seq_len": null, + "mad_max_threshold": 0.001 + } ] -} \ No newline at end of file +} diff --git a/tests/transformers/models/audio_models/test_audio_embedding_models.py b/tests/transformers/models/audio_models/test_audio_embedding_models.py index 64dc06a59..82c613e55 100644 --- a/tests/transformers/models/audio_models/test_audio_embedding_models.py +++ b/tests/transformers/models/audio_models/test_audio_embedding_models.py @@ -139,7 +139,6 @@ def check_ctc_pytorch_vs_kv_vs_ort_vs_ai100( qnn_config: Optional[str] = None, compare_results: Optional[bool] = False, ): - 
replace_transformers_quantizers() model_config = {"model_name": model_name} model_config["n_layer"] = n_layer @@ -200,7 +199,6 @@ def check_ctc_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_full_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_ctc_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, compare_results=True, manual_cleanup=manual_cleanup, num_devices=4 @@ -211,7 +209,6 @@ def test_full_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_ctc_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=4, manual_cleanup=manual_cleanup) diff --git a/tests/transformers/models/audio_models/test_speech_seq2seq_models.py b/tests/transformers/models/audio_models/test_speech_seq2seq_models.py index 6509d02fe..0c6fb2908 100644 --- a/tests/transformers/models/audio_models/test_speech_seq2seq_models.py +++ b/tests/transformers/models/audio_models/test_speech_seq2seq_models.py @@ -374,7 +374,6 @@ def check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_full_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, compare_results=True, manual_cleanup=manual_cleanup, num_devices=4 @@ -385,7 +384,6 @@ def test_full_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=4, manual_cleanup=manual_cleanup) diff --git a/tests/transformers/models/causal_lm_models/check_causal_models.py b/tests/transformers/models/causal_lm_models/check_causal_models.py index cc2d074a0..f878acbe7 100644 --- a/tests/transformers/models/causal_lm_models/check_causal_models.py +++ b/tests/transformers/models/causal_lm_models/check_causal_models.py @@ -57,7 +57,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( retain_full_kv: Optional[bool] = None, compare_results: bool = False, ): - torch.manual_seed(42) replace_transformers_quantizers() model_hf = load_hf_causal_lm_model(model_name, num_hidden_layers=n_layer, config=config) diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py index 4bf067e7c..0568939cd 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_blocking_hqkv.py @@ -31,7 +31,6 @@ @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -77,7 +76,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manu @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -123,7 +121,6 @@ 
def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manua @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -178,7 +175,6 @@ def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100(model_name, man @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -244,7 +240,6 @@ def test_full_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, m @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 @@ -310,7 +305,6 @@ def test_few_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, ma @pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV[:1]) def test_dummy_causal_all_blocking_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - HEAD_BLOCK_SIZE = 8 NUM_KV_BLOCKS = 2 NUM_Q_BLOCKS = 2 diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py index 8dbb0915b..8c61cdc98 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_models.py @@ -33,7 +33,6 @@ @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - if model_name in ModelConfig.FULL_MODEL_TESTS_TO_SKIP: pytest.skip(f"Skipping full model test for {model_name} due to resource constraints.") check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @@ -55,7 +54,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup) @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, @@ -89,7 +87,6 @@ def test_full_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_few_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -104,7 +101,6 @@ def test_few_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_causal) def test_dummy_causal_lm_pytorch_vs_ort_vs_ai100_cb(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py b/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py index b6641d795..f5f2384e6 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py +++ b/tests/transformers/models/causal_lm_models/test_causal_lm_pl1.py @@ -32,7 +32,6 @@ @pytest.mark.parametrize("model_name", test_models_pl1) 
@pytest.mark.parametrize("retain_full_kv", [True, False]) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") @@ -52,7 +51,6 @@ def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_ful @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -71,7 +69,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") @@ -97,7 +94,6 @@ def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(model_name, retain_fu @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -117,7 +113,6 @@ def test_full_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_ @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") torch.manual_seed(42) @@ -137,7 +132,6 @@ def test_few_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_f @pytest.mark.parametrize("model_name", test_models_pl1) @pytest.mark.parametrize("retain_full_kv", [True, False]) def test_dummy_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1_CB(model_name, retain_full_kv, manual_cleanup): - if model_name == "gpt2" and retain_full_kv: pytest.skip("Skipping test for gpt2 with retain_full_kv=True as it is not supported.") diff --git a/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py b/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py index 0b488a503..9d02acbd2 100644 --- a/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py +++ b/tests/transformers/models/causal_lm_models/test_causal_tlm_models.py @@ -32,7 +32,6 @@ @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, @@ -46,7 +45,6 @@ def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanu @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - n_layer = 
get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -61,7 +59,6 @@ def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, @@ -81,7 +78,6 @@ def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, manual_clean @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, num_speculative_tokens=Constants.NUM_SPECULATIVE_TOKENS, @@ -96,7 +92,6 @@ def test_full_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cle @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name=model_name, @@ -112,7 +107,6 @@ def test_few_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_clea @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models_spd) def test_dummy_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100_CB(model_name, manual_cleanup): - custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py b/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py index 2ff366ece..af8c3b70f 100644 --- a/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py +++ b/tests/transformers/models/causal_lm_models/test_fp16_causal_lm.py @@ -127,7 +127,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ai100( @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_full_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) check_causal_lm_pytorch_vs_kv_vs_ai100( model_name=model_name, torch_dtype=torch.float16, manual_cleanup=manual_cleanup @@ -139,7 +138,6 @@ def test_full_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_few_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) n_layer = get_custom_n_layers(model_name) check_causal_lm_pytorch_vs_kv_vs_ai100( @@ -152,7 +150,6 @@ def test_few_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): @pytest.mark.llm_model @pytest.mark.parametrize("model_name", test_models) def test_dummy_fp16_causal_lm_pytorch_vs_kv_vs_ai100(model_name, manual_cleanup): - torch.manual_seed(42) custom_config = model_config_dict[model_name] hf_config = AutoConfig.from_pretrained( diff --git a/tests/transformers/models/image_text_to_text/test_custom_dtype.py b/tests/transformers/models/image_text_to_text/test_custom_dtype.py index 95f62f1ac..f291c5d12 100644 --- a/tests/transformers/models/image_text_to_text/test_custom_dtype.py +++ b/tests/transformers/models/image_text_to_text/test_custom_dtype.py @@ -41,7 +41,6 @@ def test_full_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( model_name, kv_offload, torch_dtype, manual_cleanup ): - if model_name in ModelConfig.SKIPPED_MODELS: 
pytest.skip("Test skipped for this model due to some issues.") if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: @@ -65,7 +64,6 @@ def test_full_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( def test_few_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100_custom_dtype( model_name, kv_offload, torch_dtype, manual_cleanup ): - if model_name in ModelConfig.SKIPPED_MODELS: pytest.skip("Test skipped for this model due to some issues.") if model_name in ModelConfig.DUAL_QPC_MODELS and not kv_offload: diff --git a/tests/transformers/models/image_text_to_text/test_embedding_mad.py b/tests/transformers/models/image_text_to_text/test_embedding_mad.py new file mode 100644 index 000000000..5794f469a --- /dev/null +++ b/tests/transformers/models/image_text_to_text/test_embedding_mad.py @@ -0,0 +1,130 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import importlib.util +import json +import os +from pathlib import Path +from typing import Any, Dict, List + +import pytest +import torch +import torch.nn.functional as F +from huggingface_hub import snapshot_download +from transformers import AutoConfig + +from QEfficient.utils.test_utils import load_vlm_model + +CONFIG_PATH = "tests/configs/image_text_model_configs.json" + +DEFAULT_MAD_MAX = 1e-3 + +EXAMPLE_QUERIES = [ + {"text": "A woman playing with her dog on a beach at sunset."}, +] + +EXAMPLE_DOCUMENTS = [ + {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}, +] + +with open(CONFIG_PATH, "r") as f: + config_data = json.load(f) + embedding_models = config_data["image_text_embedding_models"] + +test_embedding_models = [model_config["model_name"] for model_config in embedding_models] +embedding_model_config_dict = {model["model_name"]: model for model in embedding_models} + + +def _load_embedder_cls(): + repo_root = Path(__file__).resolve().parents[4] + example_path = repo_root / "examples/image_text_to_text/models/qwen3vl/embedding/qwen3_vl_embedding.py" + spec = importlib.util.spec_from_file_location("qwen3_vl_embedding_example", str(example_path)) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module.QEffQwen3VLEmbedder + + +def _resolve_model_source(model_name_or_path: str) -> str: + if os.path.isdir(model_name_or_path): + return model_name_or_path + return snapshot_download(repo_id=model_name_or_path) + + +def _compute_cpu_embeddings(model_hf, embedder, model_inputs: List[Dict[str, Any]]) -> torch.Tensor: + embedding_rows = [] + for entry in model_inputs: + conversation = embedder.format_model_input( + text=entry.get("text"), + image=entry.get("image"), + video=entry.get("video"), + instruction=entry.get("instruction"), + ) + tokenized = embedder._tokenize_conversation(conversation) + hf_inputs = {} + for key, value in tokenized.items(): + hf_inputs[key] = value.to(next(model_hf.parameters()).device) if torch.is_tensor(value) else value + + with torch.no_grad(): + last_hidden_state = model_hf.model(**hf_inputs).last_hidden_state + + last_idx = tokenized["input_ids"].shape[1] - 1 + row = last_hidden_state[:, last_idx : last_idx + 1, :].reshape(last_hidden_state.shape[0], -1) + embedding_rows.append(row.detach().cpu().to(torch.float32)) + + embeddings = torch.cat(embedding_rows, 
dim=0) + return F.normalize(embeddings, p=2, dim=-1) + + +@pytest.mark.on_qaic +@pytest.mark.multimodal +@pytest.mark.regular +@pytest.mark.parametrize("model_name", test_embedding_models) +def test_qwen3_vl_embedding_cpu_vs_ai100_mad_parity(model_name): + torch.manual_seed(42) + model_cfg = embedding_model_config_dict[model_name] + model_source = _resolve_model_source(model_name) + + config = AutoConfig.from_pretrained(model_source, trust_remote_code=True, padding=True) + if hasattr(config, "use_cache"): + config.use_cache = True + if hasattr(config, "text_config") and hasattr(config.text_config, "use_cache"): + config.text_config.use_cache = True + + config.text_config.num_hidden_layers = model_cfg["num_layers"] + config.vision_config.depth = model_cfg["vision_depth"] + config.vision_config.deepstack_visual_indexes = [model_cfg["deepstack_index"]] + + model_hf = load_vlm_model(config) + model_hf.eval() + + QEffQwen3VLEmbedder = _load_embedder_cls() + embedder = QEffQwen3VLEmbedder( + model_name_or_path=model_source, + ctx_len=model_cfg["ctx_len"], + num_cores=16, + num_devices=1, + compile_prefill_seq_len=model_cfg.get("compile_prefill_seq_len", None), + num_hidden_layers=model_cfg["num_layers"], + vision_depth=model_cfg["vision_depth"], + deepstack_index=model_cfg["deepstack_index"], + ) + + model_inputs = EXAMPLE_QUERIES + EXAMPLE_DOCUMENTS + cpu_embeddings = _compute_cpu_embeddings(model_hf=model_hf, embedder=embedder, model_inputs=model_inputs) + ai100_embeddings = embedder.process(model_inputs, normalize=True) + + diff = torch.abs(cpu_embeddings - ai100_embeddings) + mad_mean = float(diff.mean().item()) + mad_max = float(diff.max().item()) + threshold = float(model_cfg.get("mad_max_threshold", DEFAULT_MAD_MAX)) + + print(f"[MAD] CPU vs AI100 mean={mad_mean:.6e}, max={mad_max:.6e}") + assert mad_max <= threshold, ( + f"CPU vs AI100 MAD max {mad_max:.6e} exceeds threshold {threshold:.6e}. " + f"Check prompt formatting, tokenization, prompt-length handling, and AI100 compile args." + ) diff --git a/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py b/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py index 5c5850838..b3f42e1b0 100644 --- a/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py +++ b/tests/transformers/subfunction/test_causal_lm_blocking_subfunction.py @@ -64,7 +64,6 @@ def check_blockedKV_onnx_function_count_with_subfunction( @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_full_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). check_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup=manual_cleanup) @@ -73,7 +72,6 @@ def test_full_blockedKV_onnx_function_count_with_subfunction(model_name, manual_ @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_few_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). 
n_layer = get_custom_n_layers(model_name) @@ -84,7 +82,6 @@ def test_few_blockedKV_onnx_function_count_with_subfunction(model_name, manual_c @pytest.mark.feature @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_dummy_blockedKV_onnx_function_count_with_subfunction(model_name, manual_cleanup): - # Keep model small for test runtime, and avoid CB path (not needed for function count). hf_config = AutoConfig.from_pretrained( model_name, diff --git a/tests/transformers/subfunction/test_subfunction_vlm.py b/tests/transformers/subfunction/test_subfunction_vlm.py index baf690e63..39e2c6d0a 100644 --- a/tests/transformers/subfunction/test_subfunction_vlm.py +++ b/tests/transformers/subfunction/test_subfunction_vlm.py @@ -50,7 +50,6 @@ def check_image_text_to_text_subfunction_core( num_devices: int = 1, config: Optional[AutoConfig] = None, ): - img_size = model_config_dict[model_name]["img_size"] img_url = model_config_dict[model_name]["img_url"] query = model_config_dict[model_name]["text_prompt"] @@ -117,7 +116,6 @@ def check_image_text_to_text_subfunction_core( @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_full_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) check_image_text_to_text_subfunction_core(model_name, kv_offload=kv_offload, manual_cleanup=manual_cleanup) @@ -127,7 +125,6 @@ def test_full_image_text_to_text_subfunction(model_name, kv_offload, manual_clea @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_few_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) check_image_text_to_text_subfunction_core( model_name, @@ -142,7 +139,6 @@ def test_few_image_text_to_text_subfunction(model_name, kv_offload, manual_clean @pytest.mark.parametrize("model_name", test_mm_models) @pytest.mark.parametrize("kv_offload", [True]) def test_dummy_image_text_to_text_subfunction(model_name, kv_offload, manual_cleanup): - torch.manual_seed(42) hf_config = AutoConfig.from_pretrained( model_name, trust_remote_code=True, **model_config_dict[model_name].get("additional_params", {})