diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py
index 6e21d11b2..abd19ed35 100644
--- a/QEfficient/base/pytorch_transforms.py
+++ b/QEfficient/base/pytorch_transforms.py
@@ -4,7 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
-from typing import Dict, Tuple, Type
+from types import MethodType
+from typing import Callable, Dict, Tuple, Type
from torch import nn
@@ -87,3 +88,25 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
@classmethod
def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
raise NotImplementedError("Please implement your own method by inheriting this class")
+
+
+class ModuleMethodMapperTransform(PytorchTransform):
+ """
+    Serves as a base class for any transform that wants to map a particular method of a class to a new method implementation.
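+
+    Subclasses populate ``_match_class_replace_method`` (keyed by module class) and/or
+    ``_match_string_replace_method`` (keyed by class name, which is useful for remote-code models
+    whose classes cannot be imported directly) with ``{method_name: new_function}`` mappings.
+    A minimal sketch (the class, function, and string names below are purely illustrative):
+
+    .. code-block:: python
+
+        def patched_forward(self, x):  # hypothetical replacement implementation
+            ...
+
+        class SwapForwardTransform(ModuleMethodMapperTransform):
+            _match_class_replace_method = {nn.Linear: {"forward": patched_forward}}
+            _match_string_replace_method = {"RemoteCodeBlock": {"forward": patched_forward}}
+
+        model, transformed = SwapForwardTransform.apply(model)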
+ """
+
+    _match_class_replace_method: Dict[Type[nn.Module], Dict[str, Callable]]
+ _match_string_replace_method: Dict[str, Dict[str, Callable]]
+
+ @classmethod
+ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
+ transformed = False
+ for module in model.modules():
+ if (repl_method_map := cls._match_class_replace_method.get(type(module))) or (
+ repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__)
+ ):
+ for orig_method_name, mapped_method in repl_method_map.items():
+ setattr(module, orig_method_name, MethodType(mapped_method, module))
+ transformed = True
+
+ return model, transformed
diff --git a/QEfficient/transformers/models/internvl/__init__.py b/QEfficient/transformers/models/internvl/__init__.py
new file mode 100644
index 000000000..72ba36c8a
--- /dev/null
+++ b/QEfficient/transformers/models/internvl/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
new file mode 100644
index 000000000..023b09551
--- /dev/null
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -0,0 +1,154 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from QEfficient.utils import constants
+from QEfficient.utils._utils import get_padding_shape_from_config
+
+
+class QEffInternVLModel(nn.Module):
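+    """
+    QEfficient-specific methods (specializations, dynamic ONNX axes, dummy inputs and a
+    vision+text forward) intended to be mapped onto the original InternVL remote-code model
+    via a ``ModuleMethodMapperTransform`` subclass rather than instantiated directly.
+    """
+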
+ def get_specializations(
+ self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
+ ):
+ # TODO: check if this should be named num_crops or something else
+ num_crops = compiler_options.get("num_crops", 13)
+ prefill_seq_len = prefill_seq_len if prefill_seq_len else 3840 # 4096-256
+ ctx_len = ctx_len if ctx_len else 4096
+ img_size = img_size if img_size else 448
+
+ return [
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "num_crops": num_crops,
+ "img_size": img_size,
+ },
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "num_crops": num_crops,
+ "img_size": img_size,
+ },
+ ]
+
+ def get_onnx_dynamic_axes(
+ self,
+ ):
+ # Define dynamic axes
+ dynamic_axes = {}
+ dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
+ dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
+ dynamic_axes["pixel_values"] = {0: "num_crops", 2: "img_size", 3: "img_size"}
+
+ pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+ for i in range(self.language_model.config.num_hidden_layers):
+ for kv in ["key", "value"]:
+ dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+
+ return dynamic_axes
+
+ def get_output_names(
+ self,
+ ):
+ output_names = ["logits", "pixel_values_RetainedState"]
+ for i in range(self.language_model.config.num_hidden_layers):
+ for kv in ["key", "value"]:
+ output_names.append(f"past_{kv}.{i}_RetainedState")
+ return output_names
+
+ def get_dummy_inputs(self, kv_offload: bool = False):
+ if kv_offload:
+ raise ValueError("kv_offload method not supported for InternVL yet!")
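+        # NOTE: keep these defaults in sync with get_specializations (13 crops of 3x448x448)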
+ NUM_CROPS = 13
+ C, H, W = 3, 448, 448
+
+ # Define shapes
+ inputs_shapes = {}
+ inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+ inputs_shapes["position_ids"] = (
+ constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+ constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+ )
+ inputs_shapes["pixel_values"] = (NUM_CROPS, C, H, W)
+
+ # Define inputs
+ inputs = {}
+ inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
+ inputs["position_ids"] = (
+ torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
+ .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+ .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
+ )
+ inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
+
+ # Add data for KV
+ kv_cache_shape = get_padding_shape_from_config(
+ config=self.language_model.config,
+ batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+ seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+ )
+
+ inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
+ for i in range(self.language_model.config.num_hidden_layers):
+ for kv in ["key", "value"]:
+ inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+
+ return inputs
+
+ def forward(self, input_ids, pixel_values, position_ids, past_key_values):
+        # TODO: Check if hardcoding this is okay, i.e. check if this value is common for all InternVL models
+ IMG_CONTEXT_TOKEN = 151667
+
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
+ vit_embeds = self.extract_feature(pixel_values)
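+        # Scatter the ViT features into the positions of the IMG_CONTEXT tokens: build flat indices of
+        # those positions, gather the matching image-feature rows, and merge them with torch.where so
+        # the exported graph stays free of data-dependent control flow.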
+ B, N, C = input_embeds.shape
+ image_input_embeds = input_embeds.reshape(B * N, C)
+ image_input_ids = input_ids.reshape(B * N)
+ selected = image_input_ids == IMG_CONTEXT_TOKEN
+ indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
+ indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+ image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+ image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
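+        # During decode (seq_len == 1) there are no image tokens, so fall back to the plain text embeddings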
+ inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+ outputs = self.language_model(
+ inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+ )
+ return outputs.logits, pixel_values, outputs.past_key_values
+
+
+class QEffInternVisionEmbeddings(nn.Module):
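+    """
+    Replacement ``forward`` for InternVL's vision embeddings that resizes the position embedding to
+    the runtime patch grid with bilinear interpolation; intended to be patched onto the original
+    module via a ``ModuleMethodMapperTransform`` subclass.
+    """
+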
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channels, height, width]
+ batch_size, _, height, width = patch_embeds.shape
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+
+ pos_embed = self.position_embedding[:, 1:, :]
+ target_dtype = pos_embed.dtype
+ pos_embed = (
+ pos_embed.float()
+ .reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1)
+ .permute(0, 3, 1, 2)
+ )
+ pos_embed = (
+ F.interpolate(pos_embed, size=(height, width), mode="bilinear", align_corners=False)
+ .reshape(1, -1, height * width)
+ .permute(0, 2, 1)
+ .to(target_dtype)
+ )
+
+ position_embedding = torch.cat([self.position_embedding[:, :1, :], pos_embed], dim=1)
+
+ embeddings = embeddings + position_embedding.to(target_dtype)
+ return embeddings
diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py
index 4aedf7bfe..c8ee91a70 100644
--- a/QEfficient/transformers/models/mllama/modeling_mllama.py
+++ b/QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -48,10 +48,10 @@
from QEfficient.utils.constants import Constants
bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
-max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES
-max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES
-image_size = constants.ONNX_EXPORT_IMAGE_WIDTH
-num_channel = constants.ONNX_EXPORT_IMAGE_DEPTH
+max_num_images = 1
+max_image_tiles = 4
+image_size = 560
+num_channel = 3
seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
@@ -998,7 +998,46 @@ def forward(
)
+class QEffMllamaVisionEncoder(nn.Module):
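+    """
+    Wraps the Mllama vision tower and multi-modal projector and precomputes the projected
+    key/value states for every cross-attention layer, so that the vision path can be exported
+    and compiled as a standalone QPC for the KV-offload flow.
+    """
+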
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+ self.cross_attention_layers = self.model.config.get_text_config().cross_attention_layers
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ aspect_ratio_mask: Optional[torch.Tensor] = None,
+ aspect_ratio_ids: Optional[torch.Tensor] = None,
+ ) -> List[Tuple[torch.Tensor]]:
+ vision_outputs = self.model.vision_model(
+ pixel_values=pixel_values,
+ aspect_ratio_ids=aspect_ratio_ids,
+ aspect_ratio_mask=aspect_ratio_mask,
+ )
+ cross_attention_states = vision_outputs[0]
+ cross_attention_states = self.model.multi_modal_projector(cross_attention_states).reshape(
+ -1, cross_attention_states.shape[-2], self.model.hidden_size
+ )
+
+ bsz = pixel_values.shape[0]
+ outputs = []
+ for i in self.cross_attention_layers:
+ cross_attn = self.model.language_model.model.layers[i].cross_attn
+ key_states = cross_attn.k_proj(cross_attention_states)
+ value_states = cross_attn.v_proj(cross_attention_states)
+ key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(
+ 1, 2
+ )
+ outputs.append((key_states, value_states))
+ return outputs
+
+
class QEffMllamaForConditionalGeneration(MllamaForConditionalGeneration):
+ def get_qeff_vision_encoder(self):
+ return QEffMllamaVisionEncoder(self)
+
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 1c251961b..ebb457e65 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -6,12 +6,11 @@
# ----------------------------------------------------------------------------
import hashlib
-import logging
import sys
import warnings
from pathlib import Path
from time import perf_counter
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
import numpy as np
import torch
@@ -32,6 +31,7 @@
from QEfficient.generation.text_generation_inference import CloudAI100ExecInfoNew, PerfMetrics, get_compilation_dims
from QEfficient.transformers.models.pytorch_transforms import (
CustomOpsTransform,
+ KVCacheModuleMethodMapperTransform,
KVCacheTransform,
SpDTransform,
VlmKVOffloadTransorm,
@@ -41,8 +41,7 @@
from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
from QEfficient.utils import constants, get_padding_shape_from_config
from QEfficient.utils.cache import to_hashable
-
-logger = logging.getLogger(__file__)
+from QEfficient.utils.logging_utils import logger
class QEFFTransformersBase(QEFFBaseModel):
@@ -53,8 +52,10 @@ class QEFFTransformersBase(QEFFBaseModel):
_hf_auto_class: type
def __init__(self, model: nn.Module) -> None:
- if hasattr(model.config, "quantization_config") and not isinstance(
- model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values())
+ if (
+ hasattr(model, "config")
+ and hasattr(model.config, "quantization_config")
+ and not isinstance(model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values()))
):
raise AssertionError("Please use `from_pretrained` method to load quantized models")
@@ -85,112 +86,97 @@ def model_name(self) -> str:
return mname
-class QEFFAutoModelForCausalLM(QEFFTransformersBase):
+class QEFFAutoModel(QEFFTransformersBase):
"""
- The QEFF class is designed for manipulating any causal language model from the HuggingFace hub.
+ The QEFFAutoModel class is designed for manipulating any transformer model from the HuggingFace hub.
Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.
``Mandatory`` Args:
- :model (nn.Module): PyTorch model
- :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
- :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
-
+ :model (nn.Module): PyTorch model
.. code-block:: python
- from QEfficient import QEFFAutoModelForCausalLM
+ from QEfficient import QEFFAutoModel
from transformers import AutoTokenizer
- model_name = "gpt2"
- model = QEFFAutoModelForCausalLM.from_pretrained(model_name, num_hidden_layers=2)
- model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16, num_devices=1)
+ # Initialize the model using from_pretrained similar to transformers.AutoModel.
+ model = QEFFAutoModel.from_pretrained("model_name")
+
+ # Now you can directly compile the model for Cloud AI 100
+ model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU
+ #prepare input
tokenizer = AutoTokenizer.from_pretrained(model_name)
- model.generate(prompts=["Hi there!!"], tokenizer=tokenizer)
+ inputs = tokenizer("My name is", return_tensors="pt")
+
+ # You can now execute the model
+ model.generate(inputs)
"""
- _hf_auto_class = AutoModelForCausalLM
- _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform]
+ _hf_auto_class = AutoModel
+ _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
- def __init__(
- self,
- model: nn.Module,
- continuous_batching: bool = False,
- is_tlm: bool = False,
- **kwargs,
- ):
- # model_class_name = model.__class__.__name__
- # if not (model_class_name.endswith("ForCausalLM") or model_class_name.endswith("LMHeadModel")):
- # raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {model_class_name}")
-
- # TODO: remove from version 1.20
- if kwargs.pop("full_batch_size", None):
- continuous_batching = True
- warnings.warn(
- "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
- )
-
+ def __init__(self, model: nn.Module, **kwargs):
super().__init__(model)
-
- # Set use_cache=True to get KV values as output during ONNX export
self.model.config.use_cache = True
self.num_layers = model.config.num_hidden_layers
- self.continuous_batching = continuous_batching
-
- if is_tlm:
- # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
- self.model, transformed = SpDTransform.apply(self.model)
- self.is_tlm = is_tlm
@classmethod
- def from_pretrained(
- cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
- ):
+ @with_replaced_quantizers
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
"""
- This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
+ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModel.
Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
Args:
:pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
- :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
- :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
- :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM.
+ :args, kwargs: Additional arguments to pass to transformers.AutoModel.
.. code-block:: python
- from QEfficient import QEFFAutoModelForCausalLM
+ from QEfficient import QEFFAutoModel
from transformers import AutoTokenizer
- # Initialize the model using from_pretrained similar to transformers.AutoModelForCausalLM
- model_name = "gpt2"
- model = QEFFAutoModelForCausalLM.from_pretrained(model_name)
+ # Initialize the model using from_pretrained similar to transformers.AutoModel.
+ model = QEFFAutoModel.from_pretrained("model_name")
# Now you can directly compile the model for Cloud AI 100
- model.compile(num_cores=16) # Considering you have a Cloud AI 100 Standard SKU
+ model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU
- # You can now execute the model
+ #prepare input
tokenizer = AutoTokenizer.from_pretrained(model_name)
- model.generate(prompts=["Hi there!!"], tokenizer=tokenizer)
+ inputs = tokenizer("My name is", return_tensors="pt")
+
+ # You can now execute the model
+ model.generate(inputs)
"""
+ if kwargs.get("attn_implementation", None) not in {None, "eager"}:
+ logger.warning('Updating attn_implementation="eager"')
- if kwargs.pop("full_batch_size", None):
- continuous_batching = True
- warnings.warn(
- "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
- )
+ if kwargs.get("low_cpu_mem_usage", None):
+ logger.warning("Updating low_cpu_mem_usage=False")
- self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs)
- self.continuous_batching = continuous_batching
- return self
+ kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False, "add_pooling_layer": False})
+ try:
+ model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+            warnings.warn("Removing pooling layer from the model if it exists")
+ except TypeError:
+ kwargs.pop("add_pooling_layer", None)
+ model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+ return cls(model)
@property
def model_hash(self) -> str:
- # Compute the hash with: model_config, continuous_batching, transforms
+        # NOTE: model_config.to_diff_dict() has a "_name_or_path" attribute which is the model card name or path.
+        # Using the same card name will result in the same hash. But using a relative path for one run and an
+        # absolute path for another run will result in different hashes.
+        # The added complexity of resolving different paths to the same location is not worth pursuing.
+        # Instead, advise the user to always provide the same relative or absolute paths for local models.
+
+ # Compute the hash with: model_config, transforms
mhash = hashlib.sha256()
mhash.update(to_hashable(self.model.config.to_diff_dict()))
- mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
- mhash.update(to_hashable({"is_tlm": self.is_tlm}))
mhash.update(to_hashable(self._transform_names()))
mhash = mhash.hexdigest()[:16]
return mhash
@@ -200,52 +186,22 @@ def export(self, export_dir: Optional[str] = None) -> str:
Exports the model to ``ONNX`` format using ``torch.onnx.export``.
``Optional`` Args:
- :export_dir (str, optional): The directory path to store ONNX-graph.
+ :export_dir (str, optional): The directory path to store ONNX-graph.
Returns:
:str: Path of the generated ``ONNX`` graph.
"""
- bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
- seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
- fbs = constants.ONNX_EXPORT_EXAMPLE_FBS
- kv_cache_shape = get_padding_shape_from_config(
- self.model.config, fbs if self.continuous_batching else bs, seq_len
- )
+ bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+ seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+
example_inputs = {
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
- "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
- "past_key_values": [[] for _ in range(self.num_layers)],
- }
- dynamic_axes = {
- "input_ids": {0: "batch_size", 1: "seq_len"},
- "position_ids": {0: "batch_size", 1: "seq_len"},
+ "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64),
}
- if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d
- pkv_dynamic_axes = {
- 0: "full_batch_size" if self.continuous_batching else "batch_size",
- 1: "ctx_len",
- }
- else: # pkv is 4d
- pkv_dynamic_axes = {
- 0: "full_batch_size" if self.continuous_batching else "batch_size",
- 2: "ctx_len",
- }
- output_names = ["logits"]
-
- for i in range(self.num_layers):
- for kv in ["key", "value"]:
- example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
- dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
- output_names.append(f"past_{kv}.{i}_RetainedState")
- if self.continuous_batching:
- example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
- dynamic_axes["batch_index"] = {0: "batch_size"}
+ dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}}
- if self.is_tlm:
- nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep
- example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1)
- dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}
+ output_names = ["output"]
return self._export(
example_inputs,
@@ -259,18 +215,11 @@ def compile(
onnx_path: Optional[str] = None,
compile_dir: Optional[str] = None,
*,
- prefill_seq_len: int = 32,
- ctx_len: int = 128,
+ seq_len: int = 32,
batch_size: int = 1,
- full_batch_size: Optional[int] = None,
- kv_cache_batch_size: Optional[int] = None,
num_devices: int = 1,
num_cores: int = 16, # FIXME: Make this mandatory arg
mxfp6_matmul: bool = False,
- mxint8_kv_cache: bool = False,
- num_speculative_tokens: Optional[int] = None,
- enable_qnn: bool = False,
- qnn_config: Optional[str] = None,
**compiler_options,
) -> str:
"""
@@ -281,332 +230,32 @@ def compile(
``Optional`` Args:
:onnx_path (str, optional): Path to pre-exported onnx model.
:compile_dir (str, optional): Path for saving the qpc generated.
- :num_cores (int): Number of cores used to compile the model.
- :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
+            :seq_len (int, optional): The length of the prompt should be less than ``seq_len``. ``Defaults to 32``.
:batch_size (int, optional): Batch size. ``Defaults to 1``.
- :prefill_seq_len (int, optional): The length of the Prefill prompt should be less that ``prefill_seq_len``. ``Defaults to 32``.
- :ctx_len (int, optional): Maximum ``ctx`` that the compiled model can remember. ``Defaults to 128``.
- :full_batch_size (int, optional): Continuous batching batch size.
+ :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
+ :num_cores (int): Number of cores used to compile the model.
:mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
- :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``.
- :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
- :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``.
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
- :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
- :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
-
+ :allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
Returns:
:str: Path of the compiled ``qpc`` package.
"""
- if self.is_tlm:
- # assert num_speculative_tokens cfg is acceptable if defined
- if num_speculative_tokens is None:
- raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
- if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2:
- ValueError(
- f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
- )
- num_logits_to_keep = num_speculative_tokens + 1
- if prefill_seq_len < num_logits_to_keep:
- raise ValueError(
- f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
- )
-
- if self.continuous_batching and full_batch_size is None:
- raise TypeError("missing required argument: 'full_batch_size'")
-
- if kv_cache_batch_size and not full_batch_size:
- raise ValueError(
- "Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
- )
- kv_cache_batch_size = (
- kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
- )
- # Define prefill specialization
- prefill_specialization = {
- # Prefill is always run with single BS for continuous batching.
- "batch_size": 1 if self.continuous_batching else batch_size,
- "seq_len": prefill_seq_len,
- "ctx_len": ctx_len,
- # TODO: should be renamed to kv_cache_batch_size in specialization too
- }
- prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
- if self.continuous_batching:
- prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
- else:
- prefill_specialization.update({"batch_size": kv_cache_batch_size})
- prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
specializations = [
- prefill_specialization,
+ {"batch_size": batch_size, "seq_len": seq_len},
]
- # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
- if prefill_seq_len != 1 or self.continuous_batching:
- decode_specialization = {
- "batch_size": full_batch_size if self.continuous_batching else batch_size,
- "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
- "ctx_len": ctx_len,
- }
- if self.continuous_batching:
- decode_specialization.update({"full_batch_size": kv_cache_batch_size})
- else:
- decode_specialization.update({"batch_size": kv_cache_batch_size})
- decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
- specializations.append(decode_specialization)
-
- if enable_qnn:
- if compiler_options:
- logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
-
- qpc_path = self._qnn_compile(
- onnx_path,
- compile_dir,
- specializations=specializations,
- prefill_seq_len=prefill_seq_len,
- ctx_len=ctx_len,
- batch_size=batch_size,
- full_batch_size=full_batch_size,
- mdp_ts_num_devices=num_devices,
- num_cores=num_cores,
- mxfp6_matmul=mxfp6_matmul,
- mxint8_kv_cache=mxint8_kv_cache,
- qnn_config=qnn_config,
- )
- else:
- # Custom IO
- custom_io = {}
- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
- for suffix in ["", "_RetainedState"]:
- for i in range(self.num_layers):
- for kv in ["key", "value"]:
- custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
-
- qpc_path = self._compile(
- onnx_path,
- compile_dir,
- compile_only=True,
- retained_state=True,
- specializations=specializations,
- convert_to_fp16=True,
- mxfp6_matmul=mxfp6_matmul,
- custom_io=custom_io,
- mdp_ts_num_devices=num_devices,
- num_speculative_tokens=num_speculative_tokens,
- aic_num_cores=num_cores,
- **compiler_options,
- )
- return qpc_path
-
- # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
- def generate(
- self,
- tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
- prompts: List[str],
- device_id: List[int] = None,
- runtime_ai100: bool = True,
- **kwargs,
- ):
- """
- This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
- This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed.
- If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped.
-
- ``Mandatory`` Args:
- :tokenizer (Union[PreTrainedTokenizerFast, PreTrainedTokenizer]): Pass tokenizer of the model.
- :prompts (List[str]): List of prompts to run the execution.
-
- ``optional`` Args:
- :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
- :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.
-
- """
- if runtime_ai100:
- if not isinstance(self.qpc_path, Path):
- raise TypeError("Please run compile API first!")
- generation_len = kwargs.pop("generation_len", None)
- return QEfficient.cloud_ai_100_exec_kv(
- tokenizer,
- self.qpc_path,
- prompt=prompts,
- device_id=device_id,
- generation_len=generation_len,
- is_tlm=self.is_tlm,
- )
- else:
- raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")
-
-
-class QEFFAutoModel(QEFFTransformersBase):
- """
- The QEFFAutoModel class is designed for manipulating any transformer model from the HuggingFace hub.
- Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.
-
- ``Mandatory`` Args:
- :model (nn.Module): PyTorch model
-
- .. code-block:: python
-
- from QEfficient import QEFFAutoModel
- from transformers import AutoTokenizer
-
- # Initialize the model using from_pretrained similar to transformers.AutoModel.
- model = QEFFAutoModel.from_pretrained("model_name")
-
- # Now you can directly compile the model for Cloud AI 100
- model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU
-
- #prepare input
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- inputs = tokenizer("My name is", return_tensors="pt")
-
- # You can now execute the model
- model.generate(inputs)
- """
-
- _hf_auto_class = AutoModel
- _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
- _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
-
- def __init__(self, model: nn.Module, **kwargs):
- super().__init__(model)
- self.model.config.use_cache = True
- self.num_layers = model.config.num_hidden_layers
-
- @classmethod
- @with_replaced_quantizers
- def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
- """
- This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModel.
- Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
-
- Args:
- :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
- :args, kwargs: Additional arguments to pass to transformers.AutoModel.
-
- .. code-block:: python
-
- from QEfficient import QEFFAutoModel
- from transformers import AutoTokenizer
-
- # Initialize the model using from_pretrained similar to transformers.AutoModel.
- model = QEFFAutoModel.from_pretrained("model_name")
-
- # Now you can directly compile the model for Cloud AI 100
- model.compile(num_cores=16) # Considering you have a Cloud AI 100 SKU
-
- #prepare input
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- inputs = tokenizer("My name is", return_tensors="pt")
-
- # You can now execute the model
- model.generate(inputs)
- """
- if kwargs.get("attn_implementation", None) not in {None, "eager"}:
- logger.warning('Updating attn_implementation="eager"')
-
- if kwargs.get("low_cpu_mem_usage", None):
- logger.warning("Updating low_cpu_mem_usage=False")
-
- kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False, "add_pooling_layer": False})
- try:
- model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
- warnings.warn("Removing pooling layer from the model if exist")
- except TypeError:
- kwargs.pop("add_pooling_layer", None)
- model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
- return cls(model)
-
- @property
- def model_hash(self) -> str:
- # NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path.
- # Using same card name will result in same hash. But, using a relative path for one run and
- # absolute path for another run will result in different hash.
- # The added complexity to resolve different paths to same location is not worth pursuing.
- # Instead, advise the user to always provide same relative paths or absolute paths for local models.
-
- # Compute the hash with: model_config, transforms
- mhash = hashlib.sha256()
- mhash.update(to_hashable(self.model.config.to_diff_dict()))
- mhash.update(to_hashable(self._transform_names()))
- mhash = mhash.hexdigest()[:16]
- return mhash
-
- def export(self, export_dir: Optional[str] = None) -> str:
- """
- Exports the model to ``ONNX`` format using ``torch.onnx.export``.
-
- ``Optional`` Args:
- :export_dir (str, optional): The directory path to store ONNX-graph.
-
- Returns:
- :str: Path of the generated ``ONNX`` graph.
- """
- bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
- seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
-
- example_inputs = {
- "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
- "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64),
- }
-
- dynamic_axes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}}
-
- output_names = ["output"]
-
- return self._export(
- example_inputs,
- output_names,
- dynamic_axes,
- export_dir=export_dir,
- )
-
- def compile(
- self,
- onnx_path: Optional[str] = None,
- compile_dir: Optional[str] = None,
- *,
- seq_len: int = 32,
- batch_size: int = 1,
- num_devices: int = 1,
- num_cores: int = 16, # FIXME: Make this mandatory arg
- mxfp6_matmul: bool = False,
- **compiler_options,
- ) -> str:
- """
- This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package.
- If the model has not been exported yet, this method will handle the export process.
- You can pass any other arguments that the `qaic-exec` takes as extra kwargs.
-
- ``Optional`` Args:
- :onnx_path (str, optional): Path to pre-exported onnx model.
- :compile_dir (str, optional): Path for saving the qpc generated.
- :seq_len (int, optional): The length of the prompt should be less that ``seq_len``. ``Defaults to 32``.
- :batch_size (int, optional): Batch size. ``Defaults to 1``.
- :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
- :num_cores (int): Number of cores used to compile the model.
- :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
- :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
- :allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
- Returns:
- :str: Path of the compiled ``qpc`` package.
- """
-
- specializations = [
- {"batch_size": batch_size, "seq_len": seq_len},
- ]
-
- return self._compile(
- onnx_path,
- compile_dir,
- compile_only=True,
- specializations=specializations,
- convert_to_fp16=True,
- mxfp6_matmul=mxfp6_matmul,
- mdp_ts_num_devices=num_devices,
- aic_num_cores=num_cores,
- **compiler_options,
- )
+ return self._compile(
+ onnx_path,
+ compile_dir,
+ compile_only=True,
+ specializations=specializations,
+ convert_to_fp16=True,
+ mxfp6_matmul=mxfp6_matmul,
+ mdp_ts_num_devices=num_devices,
+ aic_num_cores=num_cores,
+ **compiler_options,
+ )
def generate(
self,
@@ -692,50 +341,13 @@ def pytorch_feature_generate(self, model, inputs: Union[torch.Tensor, np.ndarray
return model(**inputs)
-class QeffCommomVisionEncoder(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.model = model
- self.cross_attention_layers = self.model.config.get_text_config().cross_attention_layers
-
- def forward(
- self,
- pixel_values: Optional[torch.FloatTensor] = None,
- aspect_ratio_mask: Optional[torch.Tensor] = None,
- aspect_ratio_ids: Optional[torch.Tensor] = None,
- ) -> List[Tuple[torch.Tensor]]:
- vision_outputs = self.model.vision_model(
- pixel_values=pixel_values,
- aspect_ratio_ids=aspect_ratio_ids,
- aspect_ratio_mask=aspect_ratio_mask,
- )
- cross_attention_states = vision_outputs[0]
- cross_attention_states = self.model.multi_modal_projector(cross_attention_states).reshape(
- -1, cross_attention_states.shape[-2], self.model.hidden_size
- )
-
- bsz = pixel_values.shape[0]
- outputs = []
- for i in self.cross_attention_layers:
- cross_attn = self.model.language_model.model.layers[i].cross_attn
- key_states = cross_attn.k_proj(cross_attention_states)
- value_states = cross_attn.v_proj(cross_attention_states)
- key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(
- 1, 2
- )
-
- outputs.append((key_states, value_states))
- return outputs
-
-
class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
_pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
def __init__(self, model: nn.modules):
super().__init__(model)
- self.model = QeffCommomVisionEncoder(model)
+ self.model = model.get_qeff_vision_encoder()
def export(self, inputs, output_names, dynamic_axes, export_dir=None):
return self._export(inputs, output_names, dynamic_axes, export_dir)
@@ -934,7 +546,7 @@ def compile(
self.export()
print("compiling vision model")
- self.vision_model._compile(
+ self.vision_model.compile(
compile_dir,
compile_only=True,
specializations=vision_specializations,
@@ -973,13 +585,12 @@ def compile(
if output_name.startswith("past_"):
custom_io_lang[output_name] = kv_cache_dtype
- print("generating lang model")
- compiler_options.update({"retained-state": True})
- self.lang_model._compile(
+ self.lang_model.compile(
compile_dir,
compile_only=True,
specializations=lang_specializations,
convert_to_fp16=True,
+ retained_state=True,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
@@ -1136,6 +747,7 @@ class _QEFFAutoModelForImageTextToText1QPC(QEFFTransformersBase):
GPTQToMatmulNbitsTransform,
CustomOpsTransform,
KVCacheTransform,
+ KVCacheModuleMethodMapperTransform,
VlmNoKVOffloadTransorm,
]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
@@ -1147,11 +759,15 @@ def __init__(
):
if kwargs.pop("full_batch_size", None):
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
-
super().__init__(model)
- self.model.config.text_config.use_cache = True
- self.input_shapes, self.output_names = None, None
- self.num_layers = model.config.text_config.num_hidden_layers
+
+ # to handle internvl models
+ if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"):
+ self.model.config.llm_config.use_cache = True
+ self.model.config.llm_config._attn_implementation = "eager"
+ self.model.config.vision_config.use_flash_attn = "false"
+ else:
+ self.model.config.text_config.use_cache = True
@classmethod
def from_pretrained(
@@ -1188,12 +804,12 @@ def export(
def compile(
self,
- img_size: int = None,
+ img_size: Optional[int] = None,
onnx_path: Optional[str] = None,
compile_dir: Optional[str] = None,
*,
- prefill_seq_len: int = None,
- ctx_len: int = None,
+ prefill_seq_len: Optional[int] = None,
+ ctx_len: Optional[int] = None,
batch_size: int = 1,
num_devices: int = 1,
num_cores: int = 16, # FIXME: Make this mandatory arg
@@ -1204,6 +820,7 @@ def compile(
output_names = self.model.get_output_names()
# Get specializations from modelling file
+ # TODO: expose this via the auto class as well
specializations = self.model.get_specializations(
batch_size=batch_size,
prefill_seq_len=prefill_seq_len,
@@ -1212,9 +829,8 @@ def compile(
**compiler_options,
)
- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
-
custom_io = {}
+ kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
# inputs
for input_name in output_names:
if input_name.endswith("_RetainedState"):
@@ -1233,11 +849,12 @@ def compile(
specializations=specializations,
convert_to_fp16=True,
mxfp6_matmul=mxfp6_matmul,
+ custom_io=custom_io,
mdp_ts_num_devices=num_devices,
aic_num_cores=num_cores,
- custom_io=custom_io,
**compiler_options,
)
+ return self.qpc_path
def get_onnx_dynamic_axes(self):
return self.model.get_onnx_dynamic_axes()
@@ -1247,7 +864,8 @@ def generate(
inputs: torch.Tensor,
streamer: Optional[TextStreamer] = None,
device_ids: List[int] = None,
- generation_len: int = None,
+ runtime_ai100: bool = True,
+ generation_len: Optional[int] = None,
) -> Union[torch.Tensor, np.ndarray]:
"""
This method generates output by executing PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -1259,6 +877,12 @@ def generate(
Returns:
:dict: Output from the ``AI_100`` or ``PyTorch`` runtime.
"""
+ if not runtime_ai100:
+ raise NotImplementedError("PyTorch execution is not supported yet for this model!")
+
return self.cloud_ai_100_generate(
inputs=inputs, device_ids=device_ids, generation_len=generation_len, streamer=streamer
@@ -1407,8 +1031,11 @@ class QEFFAutoModelForImageTextToText:
_hf_auto_class = AutoModelForImageTextToText
- def __new__(cls, model, kv_offload=False, **kwargs):
- return cls._get_qeff_class(model, kv_offload, **kwargs)
+    def __new__(cls, model: nn.Module, kv_offload=False, **kwargs):
+ if kv_offload:
+ return _QEffAutoModelForImageTextToText2QPC(model, **kwargs)
+ else:
+ return _QEFFAutoModelForImageTextToText1QPC(model, **kwargs)
@classmethod
@with_replaced_quantizers
@@ -1421,25 +1048,385 @@ def from_pretrained(cls, pretrained_model_name_or_path, kv_offload=False, **kwar
logger.warning("Updating low_cpu_mem_usage=False")
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
-
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ return cls(model, kv_offload=kv_offload, **kwargs)
+
+
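+# Models that the HF auto classes load as *ForCausalLM but that are really image-text-to-text
+# models; QEFFAutoModelForCausalLM.from_pretrained reroutes them to the mapped QEff auto class.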
+MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
+
+
+class QEFFAutoModelForCausalLM(QEFFBaseModel):
+ """
+ The QEFF class is designed for manipulating any causal language model from the HuggingFace hub.
+ Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.
+
+ ``Mandatory`` Args:
+ :model (nn.Module): PyTorch model
+        :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
+ :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
+
+
+ .. code-block:: python
+
+ from QEfficient import QEFFAutoModelForCausalLM
+ from transformers import AutoTokenizer
+
+ model_name = "gpt2"
+ model = QEFFAutoModelForCausalLM.from_pretrained(model_name, num_hidden_layers=2)
+ model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16, num_devices=1)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model.generate(prompts=["Hi there!!"], tokenizer=tokenizer)
+ """
- return cls._get_qeff_class(model, kv_offload, **kwargs)
+ _hf_auto_class = AutoModelForCausalLM
+ _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform]
+ _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+ def __init__(
+ self,
+ model: nn.Module,
+ continuous_batching: bool = False,
+ is_tlm: bool = False,
+ **kwargs,
+ ):
+ # TODO: remove from version 1.20
+ if kwargs.pop("full_batch_size", None):
+ continuous_batching = True
+ warnings.warn(
+ "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+ )
+ if hasattr(model.config, "quantization_config") and not isinstance(
+ model.config.quantization_config, tuple(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.values())
+ ):
+            logger.warning(
+                "Please use the `from_pretrained` method to load quantized models; loading them directly might give unexpected results"
+            )
+
+ super().__init__(model)
+
+ # Set use_cache=True to get KV values as output during ONNX export
+ self.model.config.use_cache = True
+ self.num_layers = model.config.num_hidden_layers
+ self.continuous_batching = continuous_batching
+
+ if is_tlm:
+ # TODO: It is possible to always apply this transform and make value of indices as last indices by default in PyTorch
+ self.model, transformed = SpDTransform.apply(self.model)
+ self.is_tlm = is_tlm
+
+ @property
+ def model_name(self) -> str:
+ mname = self.model.__class__.__name__
+ if mname.startswith("QEff") or mname.startswith("QEFF"):
+ mname = mname[4:]
+ return mname
+
+ def __repr__(self) -> str:
+        return self.__class__.__name__ + "\n" + self.model.__repr__()
@classmethod
- def _get_qeff_class(cls, model, kv_offload, **kwargs):
+ def from_pretrained(
+ cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
+ ):
"""
- Return the appropriate QEFFAutoModelForImageTextToText subclass based on kv_offload.
+ This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCausalLM.
+ Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
Args:
- model: The model instance.
- kv_offload (bool): Whether to enable key-value offloading.
- **kwargs: Additional keyword arguments for model configuration.
+ :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
+ :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later.
+ :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
+ :args, kwargs: Additional arguments to pass to transformers.AutoModelForCausalLM.
+
+ .. code-block:: python
+
+ from QEfficient import QEFFAutoModelForCausalLM
+ from transformers import AutoTokenizer
+
+ # Initialize the model using from_pretrained similar to transformers.AutoModelForCausalLM
+ model_name = "gpt2"
+ model = QEFFAutoModelForCausalLM.from_pretrained(model_name)
+
+ # Now you can directly compile the model for Cloud AI 100
+ model.compile(num_cores=16) # Considering you have a Cloud AI 100 Standard SKU
+
+ # You can now execute the model
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model.generate(prompts=["Hi there!!"], tokenizer=tokenizer)
+ """
+ if kwargs.pop("full_batch_size", None):
+ continuous_batching = True
+ warnings.warn(
+ "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+ )
+
+ if kwargs.get("attn_implementation", None) not in {None, "eager"}:
+ logger.warning('Updating attn_implementation="eager"')
+
+ if kwargs.get("low_cpu_mem_usage", None):
+ logger.warning("Updating low_cpu_mem_usage=False")
+
+ kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
+ kv_offload = kwargs.pop("kv_offload", None)
+ model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+
+ if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
+ return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
+ model, kv_offload=kv_offload if kv_offload else False
+ )
+
+ return cls(model, is_tlm=is_tlm, continuous_batching=continuous_batching)
+
+ @property
+ def model_hash(self) -> str:
+ # Compute the hash with: model_config, continuous_batching, transforms
+ mhash = hashlib.sha256()
+ mhash.update(to_hashable(self.model.config.to_diff_dict()))
+ mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
+ mhash.update(to_hashable({"is_tlm": self.is_tlm}))
+ mhash.update(to_hashable(self._transform_names()))
+ mhash = mhash.hexdigest()[:16]
+ return mhash
+
+ def export(self, export_dir: Optional[str] = None) -> str:
+ """
+ Exports the model to ``ONNX`` format using ``torch.onnx.export``.
+
+ ``Optional`` Args:
+ :export_dir (str, optional): The directory path to store ONNX-graph.
Returns:
- QEFFAutoModelForImageTextToText: An instance of the appropriate QEFFAutoModelForImageTextToText subclass.
+ :str: Path of the generated ``ONNX`` graph.
"""
- if kv_offload:
- return _QEffAutoModelForImageTextToText2QPC(model, **kwargs)
+ bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+ seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+ fbs = constants.ONNX_EXPORT_EXAMPLE_FBS
+ kv_cache_shape = get_padding_shape_from_config(
+ self.model.config, fbs if self.continuous_batching else bs, seq_len
+ )
+ example_inputs = {
+ "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
+ "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
+ "past_key_values": [[] for _ in range(self.num_layers)],
+ }
+ dynamic_axes = {
+ "input_ids": {0: "batch_size", 1: "seq_len"},
+ "position_ids": {0: "batch_size", 1: "seq_len"},
+ }
+ if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d
+ pkv_dynamic_axes = {
+ 0: "full_batch_size" if self.continuous_batching else "batch_size",
+ 1: "ctx_len",
+ }
+ else: # pkv is 4d
+ pkv_dynamic_axes = {
+ 0: "full_batch_size" if self.continuous_batching else "batch_size",
+ 2: "ctx_len",
+ }
+ output_names = ["logits"]
+
+ for i in range(self.num_layers):
+ for kv in ["key", "value"]:
+ example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+ dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+ output_names.append(f"past_{kv}.{i}_RetainedState")
+
+ if self.continuous_batching:
+ example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
+ dynamic_axes["batch_index"] = {0: "batch_size"}
+
+ if self.is_tlm:
+ nlk = constants.ONNX_EXPORT_EXAMPLE_NLK # Number of Logits to Keep
+ example_inputs["num_logits_to_keep"] = torch.arange(nlk).view(nlk, 1)
+ dynamic_axes["num_logits_to_keep"] = {0: "num_logits_to_keep"}
+
+ return self._export(
+ example_inputs,
+ output_names,
+ dynamic_axes,
+ export_dir=export_dir,
+ )
+
+ def compile(
+ self,
+ onnx_path: Optional[str] = None,
+ compile_dir: Optional[str] = None,
+ *,
+ prefill_seq_len: int = 32,
+ ctx_len: int = 128,
+ batch_size: int = 1,
+ full_batch_size: Optional[int] = None,
+ kv_cache_batch_size: Optional[int] = None,
+ num_devices: int = 1,
+ num_cores: int = 16, # FIXME: Make this mandatory arg
+ mxfp6_matmul: bool = False,
+ mxint8_kv_cache: bool = False,
+ num_speculative_tokens: Optional[int] = None,
+ enable_qnn: bool = False,
+ qnn_config: Optional[str] = None,
+ **compiler_options,
+ ) -> str:
+ """
+ This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package.
+ If the model has not been exported yet, this method will handle the export process.
+ You can pass any other arguments that the `qaic-exec` takes as extra kwargs.
+
+ ``Optional`` Args:
+ :onnx_path (str, optional): Path to pre-exported onnx model.
+ :compile_dir (str, optional): Path for saving the qpc generated.
+ :num_cores (int): Number of cores used to compile the model.
+ :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1.
+ :batch_size (int, optional): Batch size. ``Defaults to 1``.
+            :prefill_seq_len (int, optional): The length of the Prefill prompt should be less than ``prefill_seq_len``. ``Defaults to 32``.
+ :ctx_len (int, optional): Maximum ``ctx`` that the compiled model can remember. ``Defaults to 128``.
+ :full_batch_size (int, optional): Continuous batching batch size.
+ :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
+ :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``.
+ :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
+ :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``.
+ :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
+ :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
+ :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+
+ Returns:
+ :str: Path of the compiled ``qpc`` package.
+ """
+ if self.is_tlm:
+ # assert num_speculative_tokens cfg is acceptable if defined
+ if num_speculative_tokens is None:
+ raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
+            if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
+                raise ValueError(
+                    f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
+                )
+ num_logits_to_keep = num_speculative_tokens + 1
+ if prefill_seq_len < num_logits_to_keep:
+ raise ValueError(
+ f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
+ )
+
+ if self.continuous_batching and full_batch_size is None:
+ raise TypeError("missing required argument: 'full_batch_size'")
+
+ if kv_cache_batch_size and not full_batch_size:
+ raise ValueError(
+ "Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
+ )
+
+ kv_cache_batch_size = (
+ kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
+ )
+ # Define prefill specialization
+ prefill_specialization = {
+ # Prefill is always run with single BS for continuous batching.
+ "batch_size": 1 if self.continuous_batching else batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ # TODO: should be renamed to kv_cache_batch_size in specialization too
+ }
+ prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
+ if self.continuous_batching:
+ prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
- return _QEFFAutoModelForImageTextToText1QPC(model, **kwargs)
+ prefill_specialization.update({"batch_size": kv_cache_batch_size})
+ prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
+ specializations = [
+ prefill_specialization,
+ ]
+
+ # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
+ if prefill_seq_len != 1 or self.continuous_batching:
+ decode_specialization = {
+ "batch_size": full_batch_size if self.continuous_batching else batch_size,
+ "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
+ "ctx_len": ctx_len,
+ }
+ if self.continuous_batching:
+ decode_specialization.update({"full_batch_size": kv_cache_batch_size})
+ else:
+ decode_specialization.update({"batch_size": kv_cache_batch_size})
+ decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
+ specializations.append(decode_specialization)
+
+ if enable_qnn:
+ if compiler_options:
+ logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+
+ qpc_path = self._qnn_compile(
+ onnx_path,
+ compile_dir,
+ specializations=specializations,
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ batch_size=batch_size,
+ full_batch_size=full_batch_size,
+ mdp_ts_num_devices=num_devices,
+ num_cores=num_cores,
+ mxfp6_matmul=mxfp6_matmul,
+ mxint8_kv_cache=mxint8_kv_cache,
+ qnn_config=qnn_config,
+ )
+ else:
+ # Custom IO
+ custom_io = {}
+ kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+ for suffix in ["", "_RetainedState"]:
+ for i in range(self.num_layers):
+ for kv in ["key", "value"]:
+ custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
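+ # e.g., with a float16 KV cache this produces entries such as
+ # {"past_key.0": "float16", "past_value.0": "float16", ...,
+ #  "past_key.0_RetainedState": "float16", "past_value.0_RetainedState": "float16"},
+ # covering every layer's key/value inputs and their retained-state outputs.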
+
+ qpc_path = self._compile(
+ onnx_path,
+ compile_dir,
+ compile_only=True,
+ retained_state=True,
+ specializations=specializations,
+ convert_to_fp16=True,
+ mxfp6_matmul=mxfp6_matmul,
+ custom_io=custom_io,
+ mdp_ts_num_devices=num_devices,
+ num_speculative_tokens=num_speculative_tokens,
+ aic_num_cores=num_cores,
+ **compiler_options,
+ )
+ return qpc_path
+
+ # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
+ def generate(
+ self,
+ tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
+ prompts: List[str],
+ device_id: List[int] = None,
+ runtime_ai100: bool = True,
+ **kwargs,
+ ):
+ """
+ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
+ This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed.
+ If the number of prompts is not divisible by the ``batch_size``, the last incomplete batch will be dropped.
+
+ ``Mandatory`` Args:
+ :tokenizer (Union[PreTrainedTokenizerFast, PreTrainedTokenizer]): Pass tokenizer of the model.
+ :prompts (List[str]): List of prompts to run the execution.
+
+ ``optional`` Args:
+ :device_id (List[int]): Device IDs to run the ``qpc`` on; pass ``[0]`` for a single-device model or, e.g., ``[0, 1, 2, 3]`` for a tensor-sliced model.
+ :runtime_ai100 (bool, optional): Whether to use the ``AI_100`` runtime; only ``AI_100`` is supported by this API at present. ``Defaults to True``.
+
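+ Example (illustrative; the tokenizer checkpoint and prompt are placeholders, not taken from the original source):
+
+ .. code-block:: python
+
+     from transformers import AutoTokenizer
+
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     model.generate(tokenizer=tokenizer, prompts=["My name is"], device_id=[0])
+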
+ """
+ if runtime_ai100:
+ if not isinstance(self.qpc_path, Path):
+ raise TypeError("Please run compile API first!")
+ generation_len = kwargs.pop("generation_len", None)
+ return QEfficient.cloud_ai_100_exec_kv(
+ tokenizer,
+ self.qpc_path,
+ prompt=prompts,
+ device_id=device_id,
+ generation_len=generation_len,
+ is_tlm=self.is_tlm,
+ )
+ else:
+ raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 4ae62da49..975009f8f 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -104,7 +104,7 @@
Starcoder2Model,
)
-from QEfficient.base.pytorch_transforms import ModuleMappingTransform
+from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.models.codegen.modeling_codegen import (
@@ -149,6 +149,7 @@
QEffGPTJForCausalLM,
QEffGPTJModel,
)
+from QEfficient.transformers.models.internvl.modeling_internvl import QEffInternVisionEmbeddings, QEffInternVLModel
from QEfficient.transformers.models.llama.modeling_llama import (
QEffLlamaAttention,
QEffLlamaDecoderLayer,
@@ -378,3 +379,17 @@ class VlmNoKVOffloadTransorm(ModuleMappingTransform):
# Llama
MllamaTextCrossAttention: QEffMllamaTextCrossAttentionSingleQPC,
}
+
+
+class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform):
+ _match_string_replace_method = {
+ "InternVLChatModel": {
+ "forward": QEffInternVLModel.forward,
+ "get_dummy_inputs": QEffInternVLModel.get_dummy_inputs,
+ "get_specializations": QEffInternVLModel.get_specializations,
+ "get_onnx_dynamic_axes": QEffInternVLModel.get_onnx_dynamic_axes,
+ "get_output_names": QEffInternVLModel.get_output_names,
+ },
+ "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward},
+ }
+ _match_class_replace_method = {}
diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
index 1eba0e2e6..c052a5cb6 100644
--- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py
+++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
@@ -19,19 +19,141 @@
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
+ Qwen2Config,
Qwen2DecoderLayer,
Qwen2ForCausalLM,
Qwen2Model,
- apply_rotary_pos_emb,
+ Qwen2RotaryEmbedding,
logger,
repeat_kv,
+ rotate_half,
)
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding, but kept separate here to mirror the per-model layout used by transformers
+class QEffQwen2RotaryEmbedding(Qwen2RotaryEmbedding):
+ """
+ Copied from LlamaRotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+ The only differences are:
+ - Add static sin/cos computations.
+ """
+
+ def __init__(
+ self,
+ dim=None,
+ max_position_embeddings=2048,
+ base=10000,
+ device=None,
+ scaling_factor=1.0,
+ rope_type="default",
+ config: Optional[Qwen2Config] = None,
+ ):
+ super(Qwen2RotaryEmbedding, self).__init__() # Initialize nn.Module
+ # TODO (joao): remove the `if` below, only used for BC
+ self.rope_kwargs = {}
+ if config is None:
+ logger.warning_once(
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
+ "`config` argument. All other arguments will be removed in v4.45"
+ )
+ self.rope_kwargs = {
+ "rope_type": rope_type,
+ "factor": scaling_factor,
+ "dim": dim,
+ "base": base,
+ "max_position_embeddings": max_position_embeddings,
+ }
+ self.rope_type = rope_type
+ self.max_seq_len_cached = max_position_embeddings
+ self.original_max_seq_len = max_position_embeddings
+ else:
+ # BC: "rope_type" was originally "type"
+ if config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq)
+
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+ self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+ )
+
+
+def apply_qeff_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding,
+ cast back to the input dtypes.
+ """
+
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
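+
+ # Minimal usage sketch (illustrative tensor names; mirrors the call made in QEffQwen2Attention.forward below):
+ #   cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ #   query_states, key_states = apply_qeff_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)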
+
+
class QEffQwen2Attention(Qwen2Attention):
"""
Copied from Qwen2Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py
@@ -39,6 +161,20 @@ class QEffQwen2Attention(Qwen2Attention):
- add new args position idx for the cache_kwargs for kv retention
"""
+ def __init__(self, config, layer_idx=None):
+ super().__init__(config, layer_idx)
+ # Define the general __qeff_init__() for any changes in the init calls
+ # Set the init in the module mapping pytorch transforms
+ self.config = config
+ self.__qeff_init__()
+
+ def __qeff_init__(self):
+ self.rotary_emb = QEffQwen2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -71,18 +207,8 @@ def forward(
)
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- if position_embeddings is None:
- logger.warning_once(
- "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
- "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
- "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
- "removed and `position_embeddings` will be mandatory."
- )
- cos, sin = self.rotary_emb(value_states, position_ids)
- else:
- cos, sin = position_embeddings
-
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_qeff_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# Update the cache_kwargs with position_ids for Cloud AI 100
diff --git a/tests/transformers/models/test_image_text_to_text_intern.py b/tests/transformers/models/test_image_text_to_text_intern.py
new file mode 100644
index 000000000..c5b3ade66
--- /dev/null
+++ b/tests/transformers/models/test_image_text_to_text_intern.py
@@ -0,0 +1,236 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoConfig, AutoTokenizer, TextStreamer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from tests.transformers.models.conversation import get_conv_template
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+class InternProcessor:
+ def __init__(self, model: nn.Module, tokenizer):
+ self.model = model
+ image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size
+ patch_size = self.model.config.vision_config.patch_size
+ self.template = model.config.template
+ self.conv_template = get_conv_template(self.template)
+ self.system_message = self.conv_template.system_message
+ self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2))
+ self.tokenizer = tokenizer
+
+ def build_transform(self, input_size):
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD),
+ ]
+ )
+ return transform
+
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float("inf")
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+ def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ )
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = self.find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
+ )
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
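+
+ # Worked example (illustrative, not part of the original test): an 800x600 input has
+ # aspect ratio ~1.33, so with image_size=448 and max_num=12 the closest grid is (4, 3);
+ # the image is resized to 1792x1344 and split into 12 tiles of 448x448, plus a 448x448
+ # thumbnail when use_thumbnail=True.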
+
+ def load_image(self, image_file, input_size=448, max_num=12):
+ image = Image.open(image_file).convert("RGB")
+ transform = self.build_transform(input_size=input_size)
+ images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+ pixel_values = [transform(image) for image in images]
+ pixel_values = torch.stack(pixel_values)
+ return pixel_values
+
+ def __call__(
+ self,
+ pixel_values,
+ question,
+ history=None,
+ return_history=False,
+ num_patches_list=None,
+ IMG_START_TOKEN="<img>",
+ IMG_END_TOKEN="</img>",
+ IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
+ verbose=False,
+ ) -> str:
+ if history is None and pixel_values is not None and "<image>" not in question:
+ question = "<image>\n" + question
+ if num_patches_list is None:
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+ img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+ self.model.img_context_token_id = img_context_token_id
+ template = get_conv_template(self.template)
+ template.system_message = self.system_message
+ history = [] if history is None else history
+ for old_question, old_answer in history:
+ template.append_message(template.roles[0], old_question)
+ template.append_message(template.roles[1], old_answer)
+ template.append_message(template.roles[0], question)
+ template.append_message(template.roles[1], None)
+ query = template.get_prompt()
+ if verbose and pixel_values is not None:
+ image_bs = pixel_values.shape[0]
+ print(f"dynamic ViT batch size: {image_bs}")
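+ # Each "<image>" placeholder in the prompt is expanded to
+ # IMG_START_TOKEN + num_image_token * num_patches context tokens + IMG_END_TOKEN.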
+ for num_patches in num_patches_list:
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+ query = query.replace("<image>", image_tokens, 1)
+ return query
+
+
+@pytest.mark.on_qaic
+def test_image_text_to_text_intern():
+ model_name = "OpenGVLab/InternVL2_5-1B"
+ config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) # noqa: F841
+ config.llm_config.num_hidden_layers = 1
+ config.vision_config.num_hidden_layers = 1
+ model = QEFFAutoModelForCausalLM.from_pretrained(
+ model_name, kv_offload=False, config=config, trust_remote_code=True
+ )
+ # model = QEFFAutoModelForCausalLM.from_pretrained(model_name, kv_offload=False, trust_remote_code=True)
+
+ model.export()
+ model.compile(num_cores=14)
+
+ ### Pytorch execution
+ qeff_pt_model = model.model
+
+ prompt = "Please describe the image and generate a short story around it"
+ ctx_len = 4096
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
+
+ internProcessor = InternProcessor(qeff_pt_model, tokenizer)
+ pixel_values = internProcessor.load_image(
+ "/local/mnt/workspace/open-source/efficient-transformers/image1.jpg", max_num=12
+ )
+ question = "<image>\n" + prompt
+ query = internProcessor(pixel_values, question)
+ pad_inputs = tokenizer(query, return_tensors="pt", padding="max_length", max_length=3840, padding_side="right")
+
+ inputs = tokenizer(query, return_tensors="pt")
+ inputs = dict(inputs)
+
+ batch_size, prompt_len = inputs["input_ids"].shape
+ inputs["pixel_values"] = pixel_values.clone()
+ pad_inputs["pixel_values"] = pixel_values.clone()
+ import copy # noqa: E402
+
+ orig_inputs = copy.deepcopy(pad_inputs)
+ inputs["position_ids"] = torch.arange(prompt_len).view(1, -1)
+ inputs.pop("attention_mask")
+
+ head_dim = (
+ qeff_pt_model.language_model.config.hidden_size // qeff_pt_model.language_model.config.num_attention_heads
+ )
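+ # Pre-allocate an all-zero KV cache: one (key, value) pair per decoder layer, each of
+ # shape [batch_size, num_key_value_heads, ctx_len, head_dim].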
+ inputs["past_key_values"] = [
+ tuple(
+ [
+ torch.zeros(
+ batch_size,
+ qeff_pt_model.language_model.config.num_key_value_heads,
+ ctx_len,
+ head_dim,
+ dtype=torch.float32,
+ )
+ for _ in range(2)
+ ]
+ )
+ for _ in range(qeff_pt_model.language_model.config.num_hidden_layers)
+ ]
+
+ streamer = TextStreamer(tokenizer)
+ generation_len = 10
+ generated_ids = np.full((batch_size, generation_len + 1), tokenizer.pad_token_id)
+ pt_outputs = qeff_pt_model(**inputs)
+ inputs["input_ids"] = pt_outputs[0].argmax(2)
+ inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1
+ streamer.put(inputs["input_ids"])
+ generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
+ finished_sequences = inputs["input_ids"] == tokenizer.eos_token_id
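+ # Greedy decode loop: feed the argmax token back in, advance position_ids by one each
+ # step, and stop once every sequence has emitted eos or generation_len is reached.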
+ for i in range(1, generation_len):
+ outputs = qeff_pt_model(**inputs)
+ inputs["input_ids"] = outputs[0].argmax(2)
+ print(inputs["input_ids"])
+ # print(tokenizer.decode(inputs["input_ids"]))
+ inputs["position_ids"] += 1
+ generated_ids[:, i] = inputs["input_ids"].squeeze(1)
+ finished_sequences |= inputs["input_ids"] == tokenizer.eos_token_id
+ if finished_sequences.all():
+ break
+
+ streamer.end()
+
+ generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ print(generated_texts)
+
+ exec_info = model.generate(inputs=orig_inputs, generation_len=128)
+ print(exec_info)
+ generated_ids_aic = exec_info.generated_ids
+ print(generated_ids_aic)
+ generated_texts = tokenizer.batch_decode(generated_ids_aic, skip_special_tokens=True)
+ print(generated_texts)