diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py
index 8e60d698b..be4b86321 100644
--- a/QEfficient/__init__.py
+++ b/QEfficient/__init__.py
@@ -6,16 +6,21 @@
 # -----------------------------------------------------------------------------
 import os
+import warnings
+
+from QEfficient.utils import custom_format_warning
 
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
 # Placeholder for all non-transformer models registered in QEfficient
 import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.utils.logging_utils import logger
 
+# Custom warning format for a better logging experience
+warnings.formatwarning = custom_format_warning
+
 
 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""
diff --git a/QEfficient/transformers/embeddings/__init__.py b/QEfficient/transformers/embeddings/__init__.py
new file mode 100644
index 000000000..d647b73a6
--- /dev/null
+++ b/QEfficient/transformers/embeddings/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/embeddings/embedding_utils.py b/QEfficient/transformers/embeddings/embedding_utils.py
new file mode 100644
index 000000000..dd68e5fb9
--- /dev/null
+++ b/QEfficient/transformers/embeddings/embedding_utils.py
@@ -0,0 +1,125 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import inspect
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs mean pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The mean pooled last hidden states.
+    """
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
+    return torch.sum(last_hidden_states * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs average pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The average pooled last hidden states.
+    """
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs max pooling on the last hidden states of a transformer model.
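+    Padding positions are masked with a large negative value before the max is taken.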
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The max pooled last hidden states.
+    """
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
+    last_hidden_states = last_hidden_states.masked_fill(input_mask_expanded == 0, -1e9)
+    return torch.max(last_hidden_states, 1)[0]
+
+
+def cls_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Performs CLS pooling on the last hidden states of a transformer model.
+
+    Args:
+        last_hidden_states (torch.Tensor): The last hidden states of the transformer model.
+        attention_mask (torch.Tensor): The attention mask used to mask out padding tokens.
+
+    Returns:
+        torch.Tensor: The CLS pooled last hidden states.
+    """
+    return last_hidden_states[:, 0]
+
+
+POOLING_MAP = {
+    "mean": mean_pooling,
+    "avg": average_pool,
+    "cls": cls_pooling,
+    "max": max_pooling,
+}
+
+
+class PooledModel(nn.Module):
+    """
+    Adds pooling functionality to an embedding model.
+    """
+
+    def __init__(self, base_model, pooling_fn):
+        super().__init__()
+        self.config = base_model.config
+        self.base_model = base_model
+        self.pooling_fn = pooling_fn
+
+    def forward(
+        self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
+    ):
+        output = self.base_model(input_ids, attention_mask, **kwargs)
+        return self.pooling_fn(output[0], attention_mask)
+
+
+def validate_user_pooling_function(user_function):
+    """
+    Validate a user-provided pooling function to ensure it meets the required interface.
+
+    The function should take two arguments:
+    - last_hidden_states (torch.Tensor): The last hidden states of the model.
+    - attention_mask (torch.Tensor): The attention mask of the input sequence.
+
+    It should return a torch.Tensor representing the pooled output.
+
+    Args:
+        user_function (callable): The user-provided pooling function.
+
+    Raises:
+        TypeError: If the provided pooling function is not callable.
+        ValueError: If the user-provided function does not meet the required interface.
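+
+    Example:
+        A minimal illustrative usage; ``last_token_pool`` is a hypothetical
+        user-defined pooling function::
+
+            def last_token_pool(last_hidden_states, attention_mask):
+                return last_hidden_states[:, -1]
+
+            pooling_fn = validate_user_pooling_function(last_token_pool)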
+ """ + + if not callable(user_function): + raise TypeError("Provided pooling function is not callable.") + + sig = inspect.signature(user_function) + required_args = {"last_hidden_states", "attention_mask"} + if not required_args.issubset(sig.parameters.keys()): + raise ValueError(f"Pooling function must accept arguments: {required_args}") + return user_function diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index dcb0f2306..dc5570dc5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -39,6 +39,7 @@ CustomOpsTransform, KVCacheModuleMethodMapperTransform, KVCacheTransform, + PoolingTransform, SpDTransform, VlmKVOffloadTransform, VlmNoKVOffloadTransform, @@ -157,23 +158,35 @@ class QEFFAutoModel(QEFFTransformersBase): _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model: nn.Module, **kwargs): + def __init__(self, model: nn.Module, pooling=None, **kwargs): super().__init__(model) - self.model.config.use_cache = True - self.num_layers = model.config.num_hidden_layers + + # Make Embedding specific transforms like appending pooling + if pooling: + self.model, _ = PoolingTransform.apply(self.model, pooling) + + self.model.base_model.config.use_cache = True + self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs): """ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModel. Once the model is initialized, you can use other methods such as export, compile, and generate on the same object. This API can also be used as exception for VLM model since transformers support loading InternChatVL models via AutoModel API we support it via AutoModelForCausalLM API Args: - :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory. - :args, kwargs: Additional arguments to pass to transformers.AutoModel. + pretrained_model_name_or_path (str): The name or path of the pre-trained model. + pooling (Optional[Union[str, Callable]], optional): The pooling method to use. Defaults to None. + Options: + - "mean": Mean pooling + - "max": Max pooling + - "cls": CLS token pooling + - "avg": Average pooling + - Callable: A custom pooling function + - None: No pooling applied .. code-block:: python @@ -181,7 +194,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): from transformers import AutoTokenizer # Initialize the model using from_pretrained similar to transformers.AutoModel. 
-            model = QEFFAutoModel.from_pretrained("model_name")
+            model = QEFFAutoModel.from_pretrained("model_name", pooling="mean")
 
             # Now you can directly compile the model for Cloud AI 100
             model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU
@@ -199,13 +212,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")
 
-        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False, "add_pooling_layer": False})
-        try:
-            model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-            warnings.warn("Removing pooling layer from the model if exist")
-        except TypeError:
-            kwargs.pop("add_pooling_layer", None)
-            model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
+        model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
         # This is support models that should be classified to in a different auto class but transformers load them via this class
         kv_offload = kwargs.pop("kv_offload", None)
@@ -214,7 +223,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
                 model, kv_offload=kv_offload
             )
 
-        return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path)
+        return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
 
     @property
     def model_hash(self) -> str:
@@ -272,7 +281,7 @@ def compile(
         onnx_path: Optional[str] = None,
         compile_dir: Optional[str] = None,
         *,
-        seq_len: int = 32,
+        seq_len: Union[int, List[int]] = 32,
         batch_size: int = 1,
         num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg
@@ -287,7 +296,7 @@
         ``Optional`` Args:
             :onnx_path (str, optional): Path to pre-exported onnx model.
            :compile_dir (str, optional): Path for saving the qpc generated.
-            :seq_len (int, optional): The length of the prompt should be less that ``seq_len``. ``Defaults to 32``.
+            :seq_len (Union[int, List[int]]): The length of the prompt should be less than ``seq_len``; a list of lengths compiles one specialization per length. ``Defaults to 32``.
             :batch_size (int, optional): Batch size. ``Defaults to 1``.
             :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
             :num_cores (int): Number of cores used to compile the model.
@@ -303,8 +312,11 @@
             :str: Path of the compiled ``qpc`` package.
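+
+        For example, to compile multiple sequence-length specializations (illustrative values):
+
+        .. code-block:: python
+
+            model.compile(num_cores=16, seq_len=[32, 64])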
""" + if isinstance(seq_len, list) and len(seq_len) >= 15: + warnings.warn("Recommended: `seq_len` should contain fewer than 15 items.") + specializations = [ - {"batch_size": batch_size, "seq_len": seq_len}, + {"batch_size": batch_size, "seq_len": sl} for sl in (seq_len if isinstance(seq_len, list) else [seq_len]) ] return self._compile( @@ -365,11 +377,22 @@ def cloud_ai_100_feature_generate( if self.qpc_session is None: self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) self.batch_size = self.qpc_session.bindings[0].dims[0] - self.seq_len = self.qpc_session.bindings[0].dims[1] - # Prepare input + + # Dynamic switching to closest seq_Len based on input_ids_len input_ids_len = inputs["input_ids"].shape[1] + + for allowed_shape in self.qpc_session.allowed_shapes: + seq_len_allowed = allowed_shape[1][1][1] + + if seq_len_allowed >= input_ids_len: + self.seq_len = seq_len_allowed + break + + # To handle single seq_len as we can't fetch allowed shapes for single seq_len + self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len + input_ids = np.array( - torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0) + torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - input_ids_len), "constant", 0) ) attention_mask = np.array( torch.nn.functional.pad( @@ -379,14 +402,21 @@ def cloud_ai_100_feature_generate( inputs = dict(input_ids=input_ids, attention_mask=attention_mask) - outputs = { - "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype( - np.float32 - ), - } - self.qpc_session.set_buffers(outputs) - outputs = self.qpc_session.run(inputs) - outputs = outputs["output"][:, :input_ids_len, :] + # TODO: Remove try and catch after compiler fix + try: + outputs = { + "output": np.random.randn(*list(self.qpc_session.bindings[2].dims)).astype(np.float32), + } + self.qpc_session.set_buffers(outputs) + outputs = self.qpc_session.run(inputs) + except Exception: + outputs = { + "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[1]).astype( + np.float32 + ), + } + self.qpc_session.set_buffers(outputs) + outputs = self.qpc_session.run(inputs) return outputs def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index edac05248..fe2a9729a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -5,8 +5,9 @@ # # ----------------------------------------------------------------------------- +import warnings from types import MethodType -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple, Union from torch import nn from transformers.models.codegen.modeling_codegen import ( @@ -145,6 +146,7 @@ from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC +from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function from QEfficient.transformers.models.codegen.modeling_codegen import ( QEffCodeGenAttention, QeffCodeGenBlock, @@ -524,3 +526,22 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "InternVisionEmbeddings": {"forward": 
     }
     _match_class_replace_method = {}
+
+
+class PoolingTransform:
+    """
+    Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, reducing the sequence of token embeddings to a fixed-size output.
+    The pooling layer can be configured to use different pooling methods, such as max pooling or average pooling.
+    """
+
+    @classmethod
+    def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]:
+        pooling_method = (
+            POOLING_MAP[pooling]
+            if isinstance(pooling, str) and pooling in POOLING_MAP
+            else validate_user_pooling_function(pooling)
+        )
+        model = PooledModel(model, pooling_method)
+        transformed = True
+        warnings.warn("Pooling is applied to the model.")
+        return model, transformed
diff --git a/QEfficient/utils/__init__.py b/QEfficient/utils/__init__.py
index f73998302..8294a3d0a 100755
--- a/QEfficient/utils/__init__.py
+++ b/QEfficient/utils/__init__.py
@@ -11,6 +11,7 @@
 )
 from QEfficient.utils._utils import (  # noqa: F401
     check_and_assign_cache_dir,
+    custom_format_warning,
     dump_qconfig,
     get_num_layers_from_config,
     get_num_layers_vlm,
diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py
index f8bc5753c..f8adff236 100644
--- a/QEfficient/utils/_utils.py
+++ b/QEfficient/utils/_utils.py
@@ -662,3 +662,9 @@ def filter_kwargs(func, kwargs):
     """
     valid_args = inspect.signature(func).parameters
     return {key: value for key, value in kwargs.items() if key in valid_args}
+
+
+def custom_format_warning(msg, category, *args, **kwargs):
+    YELLOW = "\033[93m"
+    RESET = "\033[0m"
+    return f"{YELLOW}[Warning]: {msg}{RESET}\n"
diff --git a/examples/embedding_model.py b/examples/embedding_model.py
index ecced4259..23c9cfb3d 100644
--- a/examples/embedding_model.py
+++ b/examples/embedding_model.py
@@ -1,24 +1,23 @@
 # -----------------------------------------------------------------------------
-#
+
 # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
 # SPDX-License-Identifier: BSD-3-Clause
-#
+
 # -----------------------------------------------------------------------------
 
 # This is the work example of the Embedding model with the AI 100
 # For more information, visit: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
 
 import torch
-import torch.nn.functional as F
 from transformers import AutoTokenizer
 
 from QEfficient import QEFFAutoModel as AutoModel
 
 
-def mean_pooling(model_output, attention_mask):
-    token_embeddings = model_output  # First element of model_output contains all token embeddings
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+def max_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
+    last_hidden_states = last_hidden_states.masked_fill(input_mask_expanded == 0, -1e9)
+    return torch.max(last_hidden_states, 1)[0]
 
 
 # Sentences we want sentence embeddings for
@@ -28,18 +27,22 @@ def mean_pooling(model_output, attention_mask):
 
 tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
 
-qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
-qeff_model.compile(num_cores=14)
+# You can specify the pooling strategy either as a string (e.g., "mean") or by passing a custom pooling function.
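+# Built-in options are "mean", "avg", "cls", and "max" (see POOLING_MAP in QEfficient.transformers.embeddings.embedding_utils).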
+# If no pooling is specified, the model will return its default output (typically token embeddings).
+qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling=max_pooling)
+
+# Example: Using mean pooling by specifying it as a string.
+# This will return sentence embeddings computed using mean pooling.
+# qeff_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", pooling="mean")
+
+# Here `seq_len` can be a single int or a list of sequence lengths.
+qeff_model.compile(num_cores=16, seq_len=[32, 64])
+# qeff_model.compile(num_cores=16, seq_len=32)
+
 # Tokenize sentences
 encoded_input = tokenizer(sentences, return_tensors="pt")
 
-qeff_output = torch.tensor(qeff_model.generate(encoded_input))
-
-# Perform pooling
-sentence_embeddings = mean_pooling(qeff_output, encoded_input["attention_mask"])
-# Normalize embeddings
-sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+sentence_embeddings = qeff_model.generate(encoded_input)
 
-print("Sentence embeddings:")
-print(sentence_embeddings)
+print("Sentence embeddings:", sentence_embeddings)
diff --git a/tests/transformers/models/test_embedding_models.py b/tests/transformers/models/test_embedding_models.py
index 71b2ec314..2d110faeb 100644
--- a/tests/transformers/models/test_embedding_models.py
+++ b/tests/transformers/models/test_embedding_models.py
@@ -11,17 +11,17 @@
 import numpy as np
 import onnxruntime as ort
 import pytest
+import torch
 from transformers import AutoModel, AutoTokenizer
 
+from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP
 from QEfficient.transformers.models.modeling_auto import QEFFAutoModel
 from QEfficient.utils._utils import create_json
 from QEfficient.utils.constants import Constants, QnnConstants
 
 embed_test_models = [
-    # model_name, architecture
-    "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # MPNetForMaskedLM
-    "BAAI/bge-reranker-v2-m3",  # XLMRobertaForSequenceClassification
-    "BAAI/bge-small-en-v1.5",  # BertModel
+    {"model_name": "jinaai/jina-embeddings-v2-base-code", "pooling": "mean"},
+    {"model_name": "sentence-transformers/nli-bert-base-cls-pooling", "pooling": "cls"},
 ]
 
 
@@ -31,6 +31,7 @@ def check_embed_pytorch_vs_ort_vs_ai100(
     n_layer: int = 1,
     enable_qnn: Optional[bool] = False,
     qnn_config: Optional[str] = None,
+    pooling: Optional[str] = None,
 ):
     # Prepare input
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -44,16 +45,27 @@
         trust_remote_code=True,
     )
 
+    # Original PyTorch model output
     pt_outputs = pt_model(**inputs)
-    pt_embeddings = pt_outputs[0][0].detach().numpy()
-    # Pytorch transformed model
-    qeff_model = QEFFAutoModel(pt_model, pretrained_model_name_or_path=model_name)
+    pooling_method = POOLING_MAP[pooling] if pooling else None
+    pt_embeddings = (
+        pooling_method(pt_outputs.last_hidden_state, inputs["attention_mask"])
+        if pooling
+        else pt_outputs.last_hidden_state
+    )
+
+    # QEff transformed PyTorch model
+    qeff_model = QEFFAutoModel(pt_model, pretrained_model_name_or_path=model_name, pooling=pooling)
+
+    # QEff transformed PyTorch model output
     qeff_pt_outputs = qeff_model.generate(inputs=inputs, runtime_ai100=False)
-    qeff_pt_embeddings = qeff_pt_outputs[0][0].detach().numpy()
-    mad = np.mean(np.abs(pt_embeddings - qeff_pt_embeddings))
+    qeff_pt_embeddings = qeff_pt_outputs if pooling else qeff_pt_outputs[0]
+
+    mad = torch.mean(torch.abs(pt_embeddings - qeff_pt_embeddings))
     print("Mad for PyTorch and PyTorch transformed qeff_model is ", mad)
     assert mad <= 0, f"MAD is too high for onnx and Pytorch: {mad}"
and Pytorch: {mad}" + # ONNX session load onnx_model = qeff_model.export() ort_session = ort.InferenceSession(str(onnx_model)) @@ -62,14 +74,12 @@ def check_embed_pytorch_vs_ort_vs_ai100( attention_mask = np.array(inputs["attention_mask"]) onnx_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + # Run inference onnx_outputs = ort_session.run(None, onnx_inputs) # Compare Transformed PyTorch and ONNX outputs - - pt_embeddings = pt_outputs[0][0].detach().numpy() - onnx_embeddings = onnx_outputs[0] - mad = np.mean(np.abs(pt_embeddings - onnx_embeddings)) + mad = torch.mean(torch.abs(pt_embeddings - torch.tensor(onnx_outputs[0]))) print("Mad for onnx and PyTorch is ", mad) assert mad <= 10**-5, f"MAD is too high for onnx and Pytorch: {mad}" @@ -79,21 +89,45 @@ def check_embed_pytorch_vs_ort_vs_ai100( qnn_config=qnn_config, ) ai100_output = qeff_model.generate(inputs=inputs) + qeff_ai100_embeddings = ( + ai100_output["output"] if pooling else ai100_output["output"][:, : inputs["input_ids"].shape[1], :] + ) # Compare ONNX and AI 100 outputs - mad = np.mean(np.abs(ai100_output - onnx_outputs[0])) + mad = np.mean(np.abs(qeff_ai100_embeddings - onnx_outputs[0])) print("Mad for onnx and AI 100 output is ", mad) - assert mad <= 10**-3, f"MAD is too high for onnx and Pytorch: {mad}" + assert mad <= 10**-2, f"MAD is too high for onnx and Pytorch: {mad}" assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) @pytest.mark.on_qaic -@pytest.mark.parametrize("model_name", embed_test_models) -def test_embed_model_pytorch_vs_onnx_vs_ai100(model_name): +@pytest.mark.parametrize("model", embed_test_models) +def test_embed_model_pytorch_vs_onnx_vs_ai100(model): """ Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output. """ - check_embed_pytorch_vs_ort_vs_ai100(model_name=model_name, seq_len=32, n_layer=1) + check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=32, n_layer=1) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model", embed_test_models) +def test_embed_model_pytorch_vs_onnx_vs_ai100_pooling(model): + """ + Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with pooling. + """ + check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=32, n_layer=1, pooling=model["pooling"]) + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model", embed_test_models[:1]) +def test_embed_model_pytorch_vs_onnx_vs_ai100_multiple_seq_len(model): + """ + Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with multiple seq_len. + """ + check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=[32, 20], n_layer=1) + + +########## QNN TESTS ############## @pytest.mark.on_qaic @@ -108,5 +142,42 @@ def test_embed_model_pytorch_vs_onnx_vs_ai100_qnn(model_name): create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) check_embed_pytorch_vs_ort_vs_ai100( - model_name=model_name, seq_len=32, n_layer=1, enable_qnn=True, qnn_config=qnn_config_json_path + model_name=model_name["model_name"], seq_len=32, n_layer=1, enable_qnn=True, qnn_config=qnn_config_json_path + ) + + +@pytest.mark.on_qaic +@pytest.mark.qnn +@pytest.mark.parametrize("model", embed_test_models) +def test_embed_model_pytorch_vs_onnx_vs_ai100_pooling_qnn(model): + """ + QNN Compilation path test. + Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with pooling. 
+ """ + qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") + create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) + + check_embed_pytorch_vs_ort_vs_ai100( + model_name=model["model_name"], + seq_len=32, + n_layer=1, + pooling=model["pooling"], + enable_qnn=True, + qnn_config=qnn_config_json_path, + ) + + +@pytest.mark.on_qaic +@pytest.mark.qnn +@pytest.mark.parametrize("model", [embed_test_models[0]]) +def test_embed_model_pytorch_vs_onnx_vs_ai100_multiple_seq_len_qnn(model): + """ + QNN Compilation path test. + Test function to validate output of the Pytorch, ONNX and AI 100 runtime model output with multiple seq_len. + """ + qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json") + create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG) + + check_embed_pytorch_vs_ort_vs_ai100( + model_name=model["model_name"], seq_len=[32, 20], n_layer=1, enable_qnn=True, qnn_config=qnn_config_json_path )