diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 38ce0ca42..f4a1ed896 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -24,14 +24,16 @@ QEFFAutoModelForCausalLM, QEFFAutoModelForCTC, QEFFAutoModelForImageTextToText, + QEFFAutoModelForMultimodalLM, QEFFAutoModelForSequenceClassification, QEFFAutoModelForSpeechSeq2Seq, QEFFCommonLoader, ) from QEfficient.compile.compile_helper import compile -from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline -from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline -from QEfficient.diffusers.pipelines.wan.pipeline_wan_i2v import QEffWanImageToVideoPipeline + +# from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline +# from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline +# from QEfficient.diffusers.pipelines.wan.pipeline_wan_i2v import QEffWanImageToVideoPipeline from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv from QEfficient.peft import QEffAutoPeftModelForCausalLM @@ -55,6 +57,7 @@ "QEFFAutoModelForCTC", "QEffAutoPeftModelForCausalLM", "QEFFAutoModelForImageTextToText", + "QEFFAutoModelForMultimodalLM", "QEFFAutoModelForSequenceClassification", "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", diff --git a/QEfficient/base/__init__.py b/QEfficient/base/__init__.py index 8462d8356..8b8f4b4a5 100644 --- a/QEfficient/base/__init__.py +++ b/QEfficient/base/__init__.py @@ -11,6 +11,7 @@ QEFFAutoModelForCausalLM, QEFFAutoModelForCTC, QEFFAutoModelForImageTextToText, + QEFFAutoModelForMultimodalLM, QEFFAutoModelForSequenceClassification, QEFFAutoModelForSpeechSeq2Seq, ) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index e8d9e004c..fe3002777 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch -from transformers.cache_utils import Cache, CacheLayerMixin, EncoderDecoderCache, HybridCache, HybridChunkedCache +from transformers.cache_utils import Cache, CacheLayerMixin, EncoderDecoderCache from QEfficient.customop import ( CtxGatherFunc, @@ -510,211 +510,211 @@ def to_legacy_cache(self): return legacy_cache -# TODO:This function will be depercated in future. -class QEffHybridCache(HybridCache): - def __init__(self, config, batch_size, max_cache_len): - super().__init__(config, batch_size, max_cache_len=max_cache_len) - self.key_cache: List[torch.Tensor] = [] - self.value_cache: List[torch.Tensor] = [] - - @classmethod - def from_legacy_cache( - cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "HybridCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls(config, batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. 
- """ - return len(self.key_cache) - - def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or len(self.key_cache[layer_idx]) == 0 # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - else: - position_ids = cache_kwargs.get("position_ids") - sliding_window_pattern = cache_kwargs.get("sliding_window_pattern") - is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern)) - layer_ctx_len = self.key_cache[layer_idx].shape[2] - kv_position_ids = torch.where( - (~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1) - ) - - kv_position_ids = torch.where( - is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2), - (position_ids + 1) % layer_ctx_len, - kv_position_ids, - ) - - valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) - key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) - value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Original Gather - ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) - ctx_indices = torch.arange(ctx_len)[None, None, ...] 
- gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - - all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 - rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) - rolling_indices = rolling_indices[:ctx_len] - final_indices = torch.where( - (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices - ) - k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) - ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) - return k_out, v_out - - -# TODO:This function will be depercated in future. -class QEffHybridChunkedCache(HybridChunkedCache): - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or len(self.key_cache[layer_idx]) == 0 # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `HybridChunkedCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "HybridChunkedCache": - """Converts a cache in the legacy cache format into an equivalent `HybridChunkedCache`. 
Used for - backward compatibility.""" - cache = cls(config, max_batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if len(self.key_cache) <= layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - k_out, v_out = key_states, value_states - - else: - position_ids = cache_kwargs.get("position_ids") - is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) - - # Update the position_ids to handle the sliding window - layer_ctx_len = self.key_cache[layer_idx].shape[2] - kv_position_ids = torch.where( - (~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1) - ) - - kv_position_ids = torch.where( - is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2), - (position_ids + 1) % layer_ctx_len, - kv_position_ids, - ) - - valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) - key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) - value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Original Gather - ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) - ctx_len = min(layer_ctx_len, ctx_len) - ctx_indices = torch.arange(ctx_len)[None, None, ...] - gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit - if torch.onnx.is_in_onnx_export(): - invalid_idx_value = torch.iinfo(torch.int32).max - else: - invalid_idx_value = 0 - ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - - # Rolling indices for sliding window - all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 - rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) - rolling_indices = rolling_indices[:ctx_len] - final_indices = torch.where( - (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices - ) - k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) - ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) - v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) - return k_out, v_out +# # TODO:This function will be depercated in future. 
+# class QEffHybridCache(HybridCache): +# def __init__(self, config, batch_size, max_cache_len): +# super().__init__(config, batch_size, max_cache_len=max_cache_len) +# self.key_cache: List[torch.Tensor] = [] +# self.value_cache: List[torch.Tensor] = [] + +# @classmethod +# def from_legacy_cache( +# cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None +# ) -> "HybridCache": +# """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for +# backward compatibility.""" +# cache = cls(config, batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) +# if past_key_values is not None: +# for layer_idx in range(len(past_key_values)): +# key_states, value_states = past_key_values[layer_idx] +# cache.update(key_states, value_states, layer_idx) +# return cache + +# def __len__(self): +# """ +# Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds +# to the number of layers in the model. +# """ +# return len(self.key_cache) + +# def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int: +# """Returns the sequence length of the cached states. A layer index can be optionally passed.""" +# # TODO: deprecate this function in favor of `cache_position` +# is_empty_layer = ( +# len(self.key_cache) == 0 # no cache in any layer +# or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it +# or len(self.key_cache[layer_idx]) == 0 # the layer has no cache +# ) +# layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 +# return layer_seq_length + +# def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: +# """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. 
Used for +# backward compatibility.""" +# legacy_cache = () +# for layer_idx in range(len(self)): +# legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) +# return legacy_cache + +# def update( +# self, +# key_states: torch.Tensor, +# value_states: torch.Tensor, +# layer_idx: int, +# cache_kwargs: Optional[Dict[str, Any]] = None, +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# if len(self.key_cache) <= layer_idx: +# self.key_cache.append(key_states) +# self.value_cache.append(value_states) +# k_out, v_out = key_states, value_states +# else: +# position_ids = cache_kwargs.get("position_ids") +# sliding_window_pattern = cache_kwargs.get("sliding_window_pattern") +# is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern)) +# layer_ctx_len = self.key_cache[layer_idx].shape[2] +# kv_position_ids = torch.where( +# (~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1) +# ) + +# kv_position_ids = torch.where( +# is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2), +# (position_ids + 1) % layer_ctx_len, +# kv_position_ids, +# ) + +# valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) +# key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) +# value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) +# self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) +# self.value_cache[layer_idx] = CtxScatterFunc.apply( +# self.value_cache[layer_idx], kv_position_ids, value_states +# ) +# k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + +# # Original Gather +# ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2]) +# ctx_indices = torch.arange(ctx_len)[None, None, ...] +# gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) +# invalid_mask = ctx_indices > gather_limit +# invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() +# ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + +# all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 +# rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) +# rolling_indices = rolling_indices[:ctx_len] +# final_indices = torch.where( +# (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices +# ) +# k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) +# v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) +# ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) +# v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) +# return k_out, v_out + + +# # TODO:This function will be depercated in future. +# class QEffHybridChunkedCache(HybridChunkedCache): +# def __len__(self): +# """ +# Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds +# to the number of layers in the model. +# """ +# return len(self.key_cache) + +# def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int: +# """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" +# # TODO: deprecate this function in favor of `cache_position` +# is_empty_layer = ( +# len(self.key_cache) == 0 # no cache in any layer +# or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it +# or len(self.key_cache[layer_idx]) == 0 # the layer has no cache +# ) +# layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 +# return layer_seq_length + +# def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: +# """Converts the `HybridChunkedCache` instance into the its equivalent in the legacy cache format. Used for +# backward compatibility.""" +# legacy_cache = () +# for layer_idx in range(len(self)): +# legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) +# return legacy_cache + +# @classmethod +# def from_legacy_cache( +# cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None +# ) -> "HybridChunkedCache": +# """Converts a cache in the legacy cache format into an equivalent `HybridChunkedCache`. Used for +# backward compatibility.""" +# cache = cls(config, max_batch_size=past_key_values[0][0].shape[0], max_cache_len=past_key_values[0][0].shape[2]) +# if past_key_values is not None: +# for layer_idx in range(len(past_key_values)): +# key_states, value_states = past_key_values[layer_idx] +# cache.update(key_states, value_states, layer_idx) +# return cache + +# def update( +# self, +# key_states: torch.Tensor, +# value_states: torch.Tensor, +# layer_idx: int, +# cache_kwargs: Optional[Dict[str, Any]] = None, +# ) -> Tuple[torch.Tensor, torch.Tensor]: +# # Update the cache +# if len(self.key_cache) <= layer_idx: +# self.key_cache.append(key_states) +# self.value_cache.append(value_states) +# k_out, v_out = key_states, value_states + +# else: +# position_ids = cache_kwargs.get("position_ids") +# is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) + +# # Update the position_ids to handle the sliding window +# layer_ctx_len = self.key_cache[layer_idx].shape[2] +# kv_position_ids = torch.where( +# (~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1) +# ) + +# kv_position_ids = torch.where( +# is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2), +# (position_ids + 1) % layer_ctx_len, +# kv_position_ids, +# ) + +# valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) +# key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) +# value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) +# self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) +# self.value_cache[layer_idx] = CtxScatterFunc.apply( +# self.value_cache[layer_idx], kv_position_ids, value_states +# ) +# k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + +# # Original Gather +# ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) +# ctx_len = min(layer_ctx_len, ctx_len) +# ctx_indices = torch.arange(ctx_len)[None, None, ...] 
+# gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) +# invalid_mask = ctx_indices > gather_limit +# if torch.onnx.is_in_onnx_export(): +# invalid_idx_value = torch.iinfo(torch.int32).max +# else: +# invalid_idx_value = 0 +# ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + +# # Rolling indices for sliding window +# all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 +# rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) +# rolling_indices = rolling_indices[:ctx_len] +# final_indices = torch.where( +# (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices +# ) +# k_out = CtxGatherFunc.apply(k_out, final_indices, ctx_len) +# v_out = CtxGatherFunc.apply(v_out, final_indices, ctx_len) +# ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) +# v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) +# return k_out, v_out # This is a hack for now, until we get to merging this code with HybridCache class, @@ -729,9 +729,7 @@ def __init__(self, config, batch_size, max_cache_len, sliding_window_len): self.value_cache: List[torch.Tensor] = [] @classmethod - def from_legacy_cache( - cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "HybridCache": + def from_legacy_cache(cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None): """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for backward compatibility.""" @@ -855,9 +853,7 @@ def __init__(self, config, batch_size, max_cache_len, sliding_window_len): self.value_cache: List[torch.Tensor] = [] @classmethod - def from_legacy_cache( - cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "HybridCache": + def from_legacy_cache(cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None): """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for backward compatibility.""" cache = cls( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5f0eaf2b7..63de54200 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -20,6 +20,7 @@ AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageTextToText, + AutoModelForMultimodalLM, AutoModelForSequenceClassification, AutoModelForSpeechSeq2Seq, PreTrainedTokenizer, @@ -1001,6 +1002,108 @@ def get_model_config(self) -> dict: return self.model.model.config.__dict__ +class QEffAudioEncoderForMultimodalLM(QEFFBaseModel): + """ + QEfficient wrapper for the audio encoder component of a multimodal LM. 
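+
+    This mirrors ``QEffVisionEncoderForTextImageToTextModel``: the submodule
+    returned by ``model.get_qeff_audio_encoder()`` is exported and compiled as
+    its own QPC, separately from the language decoder.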
+ """ + + _pytorch_transforms = [ + AwqToMatmulNbitsTransform, + GPTQToMatmulNbitsTransform, + CustomOpsTransform, + KVCacheTransform, + KVCacheExternalModuleMapperTransform, + ] + _onnx_transforms = [] + + def __init__(self, model: nn.modules, **kwargs): + _configure_proxy_for_model(self, kwargs.pop("enable_proxy", False)) + super().__init__(model, **kwargs) + self.model = model.get_qeff_audio_encoder() + self.hash_params["qeff_auto_class"] = self.__class__.__name__ + + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): + return self._export( + inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + ) + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + use_onnx_subfunctions: bool = False, + **compiler_options, + ) -> str: + """ + Compiles the vision encoder component to a QPC package. + + Parameters + ---------- + compile_dir : str + Directory to save the generated QPC package. + compile_only : bool + If True, only compilation occurs without running inference. + specializations : List[Dict[str, Union[int, str]]] + List of dictionaries, each specifying a compilation specialization. + convert_to_fp16 : bool + If True, converts model to FP16 precision during compilation. + mxfp6_matmul : bool + If True, uses MXFP6 compression for MatMul weights. + mdp_ts_num_devices : int + Number of devices for multi-device (tensor slicing) compilation. + aic_num_cores : int + Number of cores to use for compilation. + custom_io : Dict[str, str] + Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : + Additional compiler options passed to the underlying compilation command. + + Returns + ------- + str + Path to the compiled QPC package for the vision encoder. + """ + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + + @property + def get_model_config(self) -> dict: + """ + Get the configuration dictionary of the underlying HuggingFace vision model. + + Returns + ------- + dict + The configuration dictionary. + """ + if hasattr(self.model.model, "audio_model"): + return self.model.model.audio_model.config.__dict__ + return self.model.model.config.__dict__ + + class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): """ QEfficient wrapper for the Causal Language Model (decoder) component of a Text-to-Image-to-Text model. @@ -1089,105 +1192,898 @@ def export( use_onnx_subfunctions: bool, optional whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False - Returns - ------- - str - Path to the generated ONNX graph file for the language decoder. 
- """ - if prefill_only: - assert prefill_seq_len > 1 - if not enable_chunking and self.continuous_batching: - raise NotImplementedError( - "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" - ) - self.hash_params["prefill_only"] = True - self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) - else: - self.hash_params["prefill_only"] = False - self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + Returns + ------- + str + Path to the generated ONNX graph file for the language decoder. + """ + if prefill_only: + assert prefill_seq_len > 1 + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) + self.hash_params["prefill_only"] = True + self.__update_prefill_transform(enable=True, enable_chunking=enable_chunking) + else: + self.hash_params["prefill_only"] = False + self.__update_prefill_transform(False, retain_full_kv=kwargs.get("retain_full_kv", False)) + + return self._export( + inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + ) + + def compile( + self, + compile_dir, + compile_only, + specializations, + convert_to_fp16, + mxfp6_matmul, + mdp_ts_num_devices, + aic_num_cores, + custom_io, + use_onnx_subfunctions: bool = False, + **compiler_options, + ) -> str: + """ + Compiles the language decoder component to a QPC package. + + Parameters + ---------- + compile_dir : str + Directory to save the generated QPC package. + compile_only : bool + If True, only compilation occurs without running inference. + specializations : List[Dict[str, Union[int, str]]] + List of dictionaries, each specifying a compilation specialization. + convert_to_fp16 : bool + If True, converts model to FP16 precision during compilation. + mxfp6_matmul : bool + If True, uses MXFP6 compression for MatMul weights. + mdp_ts_num_devices : int + Number of devices for multi-device (tensor slicing) compilation. + aic_num_cores : int + Number of cores to use for compilation. + custom_io : Dict[str, str] + Custom I/O configurations for the compiler. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : + Additional compiler options passed to the underlying compilation command. + + Returns + ------- + str + Path to the compiled QPC package for the language decoder. + """ + return self._compile( + compile_dir=compile_dir, + compile_only=compile_only, + specializations=specializations, + convert_to_fp16=convert_to_fp16, + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=mdp_ts_num_devices, + aic_num_cores=aic_num_cores, + custom_io=custom_io, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + + @property + def get_model_config(self) -> dict: + """ + Get the configuration dictionary of the underlying HuggingFace language model. + + Returns + ------- + dict + The configuration dictionary. 
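+
+        Notes
+        -----
+        For multimodal wrappers the nested ``language_model`` config is returned
+        when present; otherwise the top-level model config is used.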
+ """ + if hasattr(self.model, "language_model"): + return self.model.language_model.config.__dict__ + return self.model.config.__dict__ + + +class _QEffAutoModelForImageTextToTextDualQPC: + """ + Internal class handling multimodal image-text-to-text models using a dual QPC approach. + + In this approach, the vision encoder and language model decoder are compiled + into separate QPC packages. The vision encoder's KV cache might be offloaded + to CPU or managed differently from the language model's KV cache. + """ + + _hf_auto_class = AutoModelForImageTextToText + + def __init__( + self, + model: nn.Module, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): + """ + Initializes the dual QPC multimodal model wrapper. + + Parameters + ---------- + model : nn.Module + The full HuggingFace multimodal model. + qaic_config : dict, optional + A dictionary for QAIC-specific configurations. + **kwargs : + Additional keyword arguments. + """ + if kwargs.pop("full_batch_size", None): + continuous_batching = True + warnings.warn( + "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2 + ) + self.model = model + self.config = model.config + + self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) + self.continuous_batching = continuous_batching + self.ccl_enabled = False + if qaic_config: + self.ccl_enabled = qaic_config.get("ccl_enabled", False) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + self.input_shapes, self.output_names = None, None + # ---Sampling--- + # Note: SamplerTransform should be applied after all other transforms + # are done. The role of the sampler is to just add nodes at the output of the + # previous transform function. + self.lang_model.model, _ = SamplerTransform.apply(self.lang_model.model, qaic_config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): + """ + Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path. + + Parameters + ---------- + pretrained_model_name_or_path : str + Model card name from HuggingFace or local path to model directory. + **kwargs : + Additional keyword arguments passed directly to `cls._hf_auto_class.from_pretrained`. + Note: `attn_implementation` and `low_cpu_mem_usage` are automatically + set to "eager" and False respectively to ensure compatibility. + + Returns + ------- + _QEffAutoModelForImageTextToTextDualQPC + An instance initialized with the pretrained weights. + """ + enable_proxy = kwargs.pop("enable_proxy", False) + + if kwargs.get("attn_implementation", None) not in {None, "eager"}: + logger.warning('Updating attn_implementation="eager"') + + if kwargs.get("low_cpu_mem_usage", None): + logger.warning("Updating low_cpu_mem_usage=False") + + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) + + kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {}) + + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, + **kwargs, + ) + + @property + def onnx_path(self): + """ + Get the ONNX paths for the vision and language model components. 
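+        Entries may be ``None`` for components that have not been exported yet.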
+
+        Returns
+        -------
+        List[str]
+            A list containing the ONNX paths of the vision model and the language model.
+        """
+        return [self.vision_model.onnx_path, self.lang_model.onnx_path]
+
+    def export(
+        self,
+        export_dir: Optional[str] = None,
+        use_onnx_subfunctions: bool = False,
+        skip_vision: Optional[bool] = False,
+        skip_lang: Optional[bool] = False,
+        prefill_seq_len: Optional[int] = None,
+        prefill_only: bool = False,
+        enable_chunking: bool = False,
+        **kwargs,
+    ) -> List[str]:
+        """
+        Exports both the vision encoder and language decoder components to ONNX format.
+
+        This method exports the vision component (optionally without offloading PyTorch weights)
+        and the language component (with offloading PyTorch weights).
+
+        Parameters
+        ----------
+        export_dir : str, optional
+            Directory path where the exported ONNX graphs will be saved. Default is None.
+        use_onnx_subfunctions: bool, optional
+            whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False
+        skip_vision : bool, optional
+            If True, skips exporting the vision encoder. Default is False.
+        skip_lang : bool, optional
+            If True, skips exporting the language decoder. Default is False.
+        prefill_seq_len : int, optional
+            Prefill sequence length forwarded to the language-decoder export. Default is None.
+        prefill_only : bool, optional
+            If True, exports the language decoder for the prefill phase only. Default is False.
+        enable_chunking : bool, optional
+            If True, enables chunked prefill in the exported language decoder. Default is False.
+        **kwargs :
+            Additional keyword arguments.
+
+        Returns
+        -------
+        List[str]
+            A list containing the paths to the generated ONNX graph files for both components.
+        """
+        # TODO: This is a temporary workaround, since continuous batching is enabled for only a few models.
+        # Once support is added for all models, this exception handling can be removed.
+        try:
+            inputs = self.model.get_dummy_inputs(
+                kv_offload=True,
+                continuous_batching=self.continuous_batching,
+                comp_ctx_lengths=self.comp_ctx_lengths_decode,
+            )
+            dynamic_axes = self.model.get_onnx_dynamic_axes(
+                kv_offload=True,
+                continuous_batching=self.continuous_batching,
+                comp_ctx_lengths=self.comp_ctx_lengths_decode,
+            )
+        except TypeError:
+            inputs = self.model.get_dummy_inputs(kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode)
+            dynamic_axes = self.model.get_onnx_dynamic_axes(
+                kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode
+            )
+        output_names = self.model.get_output_names(kv_offload=True)
+        if self.lang_model.model.qaic_config is not None and self.lang_model.model.qaic_config.get(
+            "include_sampler", False
+        ):
+            logits_index = output_names["lang"].index("logits")
+            output_names["lang"][logits_index] = "next_tokens"
+            inputs["lang"], output_names["lang"], dynamic_axes["lang"] = get_sampling_inputs_and_outputs(
+                example_inputs=inputs["lang"],
+                output_names=output_names["lang"],
+                dynamic_axes=dynamic_axes["lang"],
+                continuous_batching=self.continuous_batching,
+                vocab_size=self.model.language_model.config.vocab_size,
+                qaic_config=self.lang_model.model.qaic_config,
+            )
+        if not skip_vision:
+            self.vision_model.export(
+                inputs["vision"],
+                output_names["vision"],
+                dynamic_axes["vision"],
+                export_dir=export_dir,
+                offload_pt_weights=False,
+                use_onnx_subfunctions=use_onnx_subfunctions,
+            )
+
+        if prefill_only and prefill_seq_len > 1:
+            offload_pt_weights = False  # to keep weights for the decode ONNX
+        else:
+            offload_pt_weights = kwargs.get("offload_pt_weights", True)
+
+        if not skip_lang:
+            self.lang_model.export(
+                inputs["lang"],
+                output_names["lang"],
+                dynamic_axes["lang"],
+                export_dir=export_dir,
+                offload_pt_weights=offload_pt_weights,
+                use_onnx_subfunctions=use_onnx_subfunctions,
+                prefill_only=prefill_only,
+                enable_chunking=enable_chunking,
+                prefill_seq_len=prefill_seq_len,
+            )
+        return self.onnx_path
+
+    def transform(
+        self,
+        ctx_len: Optional[int] = None,
+        seq_len: Optional[int] = None,
+        bs: Optional[int] = 1,
+        num_devices: int = 
1, + qaic_config: Optional[dict] = None, + **compiler_options, + ): + self.vision_model.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + if self.audio_model is not None: + self.audio_model.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + + self.lang_model.transform( + ctx_len=ctx_len, + seq_len=seq_len, + bs=bs, + num_devices=num_devices, + qaic_config=qaic_config, + **compiler_options, + ) + + def compile( + self, + img_size: Optional[int] = None, + vision_onnx_path: Optional[str] = None, + lang_onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + prefill_seq_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, + ctx_len: Optional[int] = None, + batch_size: int = 1, + full_batch_size: Optional[int] = None, + kv_cache_batch_size: Optional[int] = None, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + mxint8_kv_cache: bool = False, + skip_vision: Optional[bool] = False, + skip_lang: Optional[bool] = False, + use_onnx_subfunctions: bool = False, + prefill_only=None, + enable_chunking=False, + qaic_config: Optional[dict] = None, + **compiler_options, + ) -> str: + """ + Compiles both the vision encoder and language decoder components into QPC packages. + + Parameters + ---------- + img_size : int, optional + The image size to compile the vision model for. Default is None. + vision_onnx_path : str, optional + Path to a pre-exported ONNX file for the vision encoder. If None, it will be exported. + lang_onnx_path : str, optional + Path to a pre-exported ONNX file for the language decoder. If None, it will be exported. + compile_dir : str, optional + Directory to save the generated QPC packages. + prefill_seq_len : int, optional + Length of the prefill prompt for the language model. Default is None. + ctx_len : int, optional + Maximum context length for the language model. Default is None. + batch_size : int, optional + Batch size. Default is 1. + full_batch_size : int, optional + Not supported for this model; must be None. + kv_cache_batch_size : int, optional + Not supported for this model; must be None. + num_devices : int, optional + Number of devices to compile for. Default is 1. + num_cores : int, optional + Number of cores to use for compilation. + mxfp6_matmul : bool, optional + Use MXFP6 compression for weights in the language model. Default is False. + mxint8_kv_cache : bool, optional + Use MXINT8 compression for KV cache. Default is False. + num_speculative_tokens : int, optional + Not supported for this model; must be None. + skip_vision : bool, optional + If True, skips compilation of the vision encoder. Default is False. + skip_lang : bool, optional + If True, skips compilation of the language decoder. Default is False. + use_onnx_subfunctions: bool, optional + whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False + **compiler_options : dict + Additional compiler options for QAIC or QNN compilers. + + Returns + ------- + Union[List[str], str, None] + A list of paths to the compiled QPC packages, or a single path if only + one component is compiled, or None if neither is compiled. 
+ + Raises + ------ + ValueError + If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. + If both `skip_lang` and `skip_vision` are True. + """ + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: + raise ValueError( + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." + ) + + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size + + output_names = self.model.get_output_names(kv_offload=True) + + # if ccl_enabled is True read Compute-Context-Length lists + if self.ccl_enabled: + if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None: + logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).") + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + # For supporting VLLM and Disaggregated with CCL + elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + ) + + # Apply compile-dependent transforms like blocking transform + self.transform( + ctx_len=ctx_len, + seq_len=prefill_seq_len, + batch_size=batch_size, + num_devices=num_devices, + qaic_config=qaic_config, + aic_num_cores=num_cores, + ) + + specializations, compiler_options = self.model.get_specializations( + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + img_size=img_size, + kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + **compiler_options, + ) + + custom_io_vision = {} + target_dtype = getattr(self.model.config, "torch_dtype", torch.float32) + kv_cache_dtype = "mxint8" if mxint8_kv_cache else CUSTOM_IO_DTYPE_MAP[target_dtype] + molmo = hasattr(self.model.config, "model_type") and self.model.config.model_type == "molmo" + if molmo: + custom_io_vision["image_masks"] = CUSTOM_IO_DTYPE_MAP[target_dtype] + custom_io_vision["pixel_values"] = CUSTOM_IO_DTYPE_MAP[target_dtype] + + for output_name in output_names["vision"]: + if output_name.startswith("past_"): + custom_io_vision[output_name] = kv_cache_dtype + else: + custom_io_vision[output_name] = CUSTOM_IO_DTYPE_MAP[target_dtype] + + if vision_onnx_path: + self.vision_model.onnx_path = vision_onnx_path + if lang_onnx_path: + self.lang_model.onnx_path = lang_onnx_path + + if vision_onnx_path is None or lang_onnx_path is None: + self.export( + use_onnx_subfunctions=use_onnx_subfunctions, + skip_vision=skip_vision, + skip_lang=skip_lang, + prefill_only=prefill_only, + enable_chunking=enable_chunking, + prefill_seq_len=prefill_seq_len, + ) + + # TODO this hould be removed once the continous batching is supported for all the models. 
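+        # Drop batching-related keys so they are not forwarded to the compiler
+        # command; they were already consumed by get_specializations above.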
+ compiler_options.pop("continuous_batching", None) + compiler_options.pop("kv_cache_batch_size", None) + compiler_options.pop("full_batch_size", None) + self.qpc_paths = {} + if not skip_vision: + vision_qpc_path = self.vision_model._compile( + compile_dir=compile_dir, + compile_only=True, + specializations=specializations["vision"], + convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"), + mxfp6_matmul=constants.VISION_MXFP6_MATMUL, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io_vision, + mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + self.qpc_paths["vision_qpc_path"] = vision_qpc_path + + # Custom NPI file options + if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options: + compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path) + + if not skip_lang: + custom_io_lang = {} + # Inputs + for output_name in output_names["lang"]: + if output_name.endswith("_RetainedState"): + custom_io_lang[output_name[: -len("_RetainedState")]] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] + if ("vision_embeds" in output_name or "deepstack_features" in output_name) + else kv_cache_dtype + ) + + # outputs + for output_name in output_names["lang"]: + if output_name.endswith("_RetainedState"): + custom_io_lang[output_name] = ( + CUSTOM_IO_DTYPE_MAP[target_dtype] + if ("vision_embeds" in output_name or "deepstack_features" in output_name) + else kv_cache_dtype + ) + if prefill_only: + specializations = specializations["lang"][:1] + qpc_key = "lang_prefill_qpc_path" + elif prefill_seq_len == 1: + specializations = specializations["lang"][-1:] + qpc_key = "lang_decode_qpc_path" + else: + specializations = specializations["lang"] + qpc_key = "lang_qpc_path" + + lang_qpc_path = self.lang_model._compile( + compile_dir=compile_dir, + compile_only=True, + retained_state=True, + specializations=specializations, + convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"), + mxfp6_matmul=mxfp6_matmul, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io_lang, + mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + self.qpc_paths.update({qpc_key: lang_qpc_path}) + return self.qpc_paths + + def generate( + self, + inputs: Optional[torch.Tensor] = None, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + processor: Optional[AutoImageProcessor] = None, + images: List[str] = None, + prompts: List[str] = None, + streamer: Optional[TextStreamer] = None, + device_ids: List[int] = None, + runtime_ai100: bool = True, + generation_len: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, + **kwargs, + ) -> Union[torch.Tensor, np.ndarray]: + """ + Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards. + + This method coordinates inference between the vision encoder and language model decoder. + + Parameters + ---------- + inputs : Dict[str, Union[torch.Tensor, np.ndarray]] + Inputs to run the execution, typically includes `pixel_values`, `input_ids`, + `attention_mask`, etc. + tokenizer : PreTrainedTokenizer or PreTrainedTokenizerFast, optional + Tokenizer for the model. Used when images and prompts are provided. + processor : AutoImageProcessor, optional + Processor for the model. Used when images and prompts are provided. 
+ images : List[str], optional + List of image paths or PIL images to process. + prompts : List[str], optional + List of text prompts corresponding to the images. + streamer : TextStreamer, optional + A streamer object to display generated tokens in real-time. Default is None. + device_ids : List[int], optional + IDs of devices for running the QPC. E.g., `[0]` for a single device or + `[0, 1, 2, 3]` for tensor slicing. Defaults to `[0]` if not specified. + runtime_ai100 : bool, optional + If True, uses the AI 100 runtime. PyTorch runtime is not supported for this model. + Default is True. + generation_len : int, optional + The maximum number of tokens to generate. If None, it's inferred from `ctx_len`. + + Returns + ------- + CloudAI100ExecInfoNew or np.ndarray + Output from the AI 100 runtime, including generated IDs and performance metrics. + + Raises + ------ + NotImplementedError + If `runtime_ai100` is False. + """ + if not runtime_ai100: + raise NotImplementedError("PyTorch execution is not supported yet for this model!") + + write_io = kwargs.pop("write_io", False) + self._write_io_dir = os.path.join(os.path.dirname(self.onnx_path[1]), "io_dir") if write_io else None + + # Use VisionLanguageGeneration for image-prompt pairs + if (processor and images) or (tokenizer and prompts): + # Create VisionLanguageGeneration instance + batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) + vlm_gen = VisionLanguageGeneration( + qeff_model=self, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=ctx_len_comp, + full_batch_size=fbs, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + image_height=image_height, + image_width=image_width, + write_io_dir=self._write_io_dir, + **kwargs, + ) + + # Call generate method + return vlm_gen.generate( + images=images, + prompts=prompts, + generation_len=generation_len, + stream=streamer is not None, + ) + + # Fallback to kv_offload_generate for direct inputs (backward compatibility) + return self.kv_offload_generate( + inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len + ) + + def kv_offload_generate( + self, + inputs: List[str] = None, + streamer: Optional[TextStreamer] = None, + device_ids: List[int] = None, + generation_len: int = None, + ): + """ + Performs generation for multimodal models with KV offloading to CPU. + + This method orchestrates the inference by running the vision encoder (if compiled) + and then iteratively running the language decoder, managing KV cache states. + + Parameters + ---------- + inputs : Dict[str, Union[torch.Tensor, np.ndarray]] + Input tensors for the multimodal model. + streamer : TextStreamer, optional + A streamer object to display generated tokens in real-time. Default is None. + device_ids : List[int], optional + IDs of devices for running the QPC. Defaults to `[0]` if not specified. + generation_len : int, optional + The maximum number of tokens to generate. If None, it's inferred from `ctx_len`. + + Returns + ------- + CloudAI100ExecInfoNew + Execution information including generated IDs and performance metrics. + + Raises + ------ + TypeError + If the language model QPC is not compiled. + AssertionError + If `generation_len` is not greater than zero. 
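+
+        Example
+        -------
+        A minimal sketch, assuming the vision and language QPCs are already
+        compiled and ``processor`` is the model's HuggingFace processor
+        (``qeff_model``, ``image`` and ``prompt`` are placeholder names)::
+
+            inputs = processor(images=image, text=prompt, return_tensors="pt")
+            exec_info = qeff_model.kv_offload_generate(inputs=inputs, generation_len=128)
+            print(exec_info.generated_ids)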
+ """ + if not self.lang_model.qpc_path: + raise TypeError("Please run compile API for language model first!") + + lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) + + if self.vision_model.qpc_path: + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) + + batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) + + pad_token_id = 1 + + # Skip inputs/outputs + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] + ) + input_len = inputs["attention_mask"].sum(1, keepdims=True) + input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + if generation_len is None: + generation_len = ctx_len - input_len.max() + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), pad_token_id) + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in inputs.items(): + inputs[k] = np.array(v) + + vision_inputs = { + k: v + for k, v in inputs.items() + if k + in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} + } + + vision_inputs_fp16 = {"pixel_values", "image_masks"} + vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) + + vision_start = perf_counter() + + vision_outputs = {} + if vision_inputs: + vision_outputs = vision_session.run(vision_inputs) + vision_end = perf_counter() + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + if "position_ids" in inputs: + lang_inputs["position_ids"] = inputs["position_ids"] + lang_inputs.pop("attention_mask") + else: + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + + not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" + if not_mllama: + lang_inputs["image_idx"] = np.array([[0]]) + if self.audio_model is not None: + lang_inputs["audio_idx"] = np.array([[0]]) + if self.vision_model.qpc_path: + vision_session.deactivate() + lang_session.activate() + + lang_session.set_buffers(vision_outputs) + + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_prefill + ] + prefill_ccl_id = 0 + 
lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + lang_start = perf_counter() + # Run prefill + chunk_inputs = lang_inputs.copy() + for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = lang_inputs["position_ids"][ + ..., i * prefill_seq_len : (i + 1) * prefill_seq_len + ] + outputs = lang_session.run(chunk_inputs) + chunk_inputs["image_idx"] = outputs["image_idx_output"] + if "audio_idx_output" in outputs: + chunk_inputs["audio_idx"] = outputs["audio_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) - return self._export( - inputs, - output_names=output_names, - dynamic_axes=dynamic_axes, - export_dir=export_dir, - offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), + prefill_time = perf_counter() - lang_start + vision_end - vision_start + # Skip inputs/outputs again + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] ) + if not_mllama: + lang_session.skip_buffers(vision_outputs.keys()) + # Get first token + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + lang_inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy() + generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1) - def compile( - self, - compile_dir, - compile_only, - specializations, - convert_to_fp16, - mxfp6_matmul, - mdp_ts_num_devices, - aic_num_cores, - custom_io, - use_onnx_subfunctions: bool = False, - **compiler_options, - ) -> str: - """ - Compiles the language decoder component to a QPC package. + if streamer: + streamer.put(lang_inputs["input_ids"][0]) - Parameters - ---------- - compile_dir : str - Directory to save the generated QPC package. - compile_only : bool - If True, only compilation occurs without running inference. - specializations : List[Dict[str, Union[int, str]]] - List of dictionaries, each specifying a compilation specialization. - convert_to_fp16 : bool - If True, converts model to FP16 precision during compilation. - mxfp6_matmul : bool - If True, uses MXFP6 compression for MatMul weights. - mdp_ts_num_devices : int - Number of devices for multi-device (tensor slicing) compilation. - aic_num_cores : int - Number of cores to use for compilation. - custom_io : Dict[str, str] - Custom I/O configurations for the compiler. - use_onnx_subfunctions: bool, optional - whether to enable ONNX subfunctions during export. Exporting PyTorch model to ONNX with modules as subfunctions helps to reduce export/compile time. Defaults to False - **compiler_options : - Additional compiler options passed to the underlying compilation command. 
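+            # CCL decode: start from the smallest compute-context-length bucket
+            # that covers the current position id; step up to larger buckets as
+            # generation proceeds.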
+ # Decode loop + if self.comp_ctx_lengths_decode is not None: + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + list_of_comp_ctx_lengths_decode = [ + np.zeros(length, dtype=np.int8) for length in self.comp_ctx_lengths_decode + ] + max_position_id = np.max(lang_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] - Returns - ------- - str - Path to the compiled QPC package for the language decoder. - """ - return self._compile( - compile_dir=compile_dir, - compile_only=compile_only, - specializations=specializations, - convert_to_fp16=convert_to_fp16, - mxfp6_matmul=mxfp6_matmul, - mdp_ts_num_devices=mdp_ts_num_devices, - aic_num_cores=aic_num_cores, - custom_io=custom_io, - use_onnx_subfunctions=use_onnx_subfunctions, - **compiler_options, - ) + decode_start = perf_counter() + for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] - @property - def get_model_config(self) -> dict: - """ - Get the configuration dictionary of the underlying HuggingFace language model. + outputs = lang_session.run(lang_inputs) + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) + self._write_io_dir = None - Returns - ------- - dict - The configuration dictionary. - """ - if hasattr(self.model, "language_model"): - return self.model.language_model.config.__dict__ - return self.model.config.__dict__ + # Prepare inputs for next iteration + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] += 1 + generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1) + if streamer: + streamer.put(lang_inputs["input_ids"][0]) + decode_end = perf_counter() + if streamer: + streamer.end() -class _QEffAutoModelForImageTextToTextDualQPC: + decode_perf = (num_token - 1) / (decode_end - decode_start) + total_time = decode_end - decode_start + prefill_time + total_perf = num_token / total_time + + return CloudAI100ExecInfoNew( + batch_size=batch_size, + generated_ids=generated_ids, + perf_metrics=PerfMetrics( + prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time + ), + ) + + +class _QEFFAutoModelForMultimodalLMMultiQPC: """ Internal class handling multimodal image-text-to-text models using a dual QPC approach. @@ -1196,7 +2092,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: to CPU or managed differently from the language model's KV cache. """ - _hf_auto_class = AutoModelForImageTextToText + _hf_auto_class = AutoModelForMultimodalLM def __init__( self, @@ -1226,6 +2122,8 @@ def __init__( self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) + self.audio_model = QEffAudioEncoderForMultimodalLM(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs) self.continuous_batching = continuous_batching self.ccl_enabled = False @@ -1289,13 +2187,17 @@ def onnx_path(self): List[str] A list containing the ONNX paths of the vision model and the language model. 
""" - return [self.vision_model.onnx_path, self.lang_model.onnx_path] + onnx_paths = [self.vision_model.onnx_path, self.lang_model.onnx_path] + if self.audio_model is not None: + onnx_paths.append(self.audio_model.onnx_path) + return onnx_paths def export( self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, skip_vision: Optional[bool] = False, + skip_audio: Optional[bool] = False, skip_lang: Optional[bool] = False, prefill_seq_len: Optional[int] = None, prefill_only: bool = False, @@ -1362,6 +2264,21 @@ def export( offload_pt_weights=False, use_onnx_subfunctions=use_onnx_subfunctions, ) + if ( + self.audio_model is not None + and "audio" in inputs + and "audio" in output_names + and "audio" in dynamic_axes + and not skip_audio + ): + self.audio_model.export( + inputs["audio"], + output_names["audio"], + dynamic_axes["audio"], + export_dir=export_dir, + offload_pt_weights=False, + use_onnx_subfunctions=use_onnx_subfunctions, + ) if prefill_only and prefill_seq_len > 1: offload_pt_weights = False # to keep weight for decode onnx @@ -1413,6 +2330,7 @@ def compile( self, img_size: Optional[int] = None, vision_onnx_path: Optional[str] = None, + audio_onnx_path: Optional[str] = None, lang_onnx_path: Optional[str] = None, compile_dir: Optional[str] = None, *, @@ -1428,6 +2346,7 @@ def compile( mxfp6_matmul: bool = False, mxint8_kv_cache: bool = False, skip_vision: Optional[bool] = False, + skip_audio: Optional[bool] = False, skip_lang: Optional[bool] = False, use_onnx_subfunctions: bool = False, prefill_only=None, @@ -1489,8 +2408,8 @@ def compile( If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. If both `skip_lang` and `skip_vision` are True. """ - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + if skip_lang and skip_vision and (skip_audio or self.audio_model is None): + raise ValueError("Expected at least one of 'skip_lang', 'skip_vision', or 'skip_audio' to be False") if self.continuous_batching and full_batch_size is None: raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") @@ -1559,13 +2478,20 @@ def compile( if vision_onnx_path: self.vision_model.onnx_path = vision_onnx_path + if audio_onnx_path and self.audio_model is not None: + self.audio_model.onnx_path = audio_onnx_path if lang_onnx_path: self.lang_model.onnx_path = lang_onnx_path - if vision_onnx_path is None or lang_onnx_path is None: + if ( + vision_onnx_path is None + or lang_onnx_path is None + or (self.audio_model is not None and audio_onnx_path is None and not skip_audio) + ): self.export( use_onnx_subfunctions=use_onnx_subfunctions, skip_vision=skip_vision, + skip_audio=skip_audio, skip_lang=skip_lang, prefill_only=prefill_only, enable_chunking=enable_chunking, @@ -1593,6 +2519,43 @@ def compile( ) self.qpc_paths["vision_qpc_path"] = vision_qpc_path + if self.audio_model is not None and not skip_audio: + custom_io_audio = {} + try: + audio_example_inputs = self.model.get_dummy_inputs( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ).get("audio", {}) + except TypeError: + audio_example_inputs = self.model.get_dummy_inputs( + kv_offload=True, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ).get("audio", {}) + for input_name, input_value in audio_example_inputs.items(): + if isinstance(input_value, torch.Tensor) and input_value.dtype.is_floating_point: + custom_io_audio[input_name] = 
CUSTOM_IO_DTYPE_MAP[target_dtype] + for output_name in output_names.get("audio", []): + custom_io_audio[output_name] = CUSTOM_IO_DTYPE_MAP[target_dtype] + + audio_specializations = specializations.get("audio") + if audio_specializations is None: + audio_specializations = specializations["vision"] + audio_qpc_path = self.audio_model._compile( + compile_dir=compile_dir, + compile_only=True, + specializations=audio_specializations, + convert_to_fp16=(CUSTOM_IO_DTYPE_MAP[target_dtype] == "float16"), + mxfp6_matmul=constants.VISION_MXFP6_MATMUL, + mdp_ts_num_devices=num_devices, + aic_num_cores=num_cores, + custom_io=custom_io_audio, + mxint8_kv_cache=mxint8_kv_cache, + use_onnx_subfunctions=use_onnx_subfunctions, + **compiler_options, + ) + self.qpc_paths["audio_qpc_path"] = audio_qpc_path + # Custom NPI file options if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options: compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path) @@ -1604,7 +2567,11 @@ def compile( if output_name.endswith("_RetainedState"): custom_io_lang[output_name[: -len("_RetainedState")]] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] - if ("vision_embeds" in output_name or "deepstack_features" in output_name) + if ( + "vision_embeds" in output_name + or "deepstack_features" in output_name + or "audio_embeds" in output_name + ) else kv_cache_dtype ) @@ -1613,7 +2580,11 @@ def compile( if output_name.endswith("_RetainedState"): custom_io_lang[output_name] = ( CUSTOM_IO_DTYPE_MAP[target_dtype] - if ("vision_embeds" in output_name or "deepstack_features" in output_name) + if ( + "vision_embeds" in output_name + or "deepstack_features" in output_name + or "audio_embeds" in output_name + ) else kv_cache_dtype ) if prefill_only: @@ -1780,6 +2751,8 @@ def kv_offload_generate( if self.vision_model.qpc_path: vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) + if self.audio_model is not None and self.audio_model.qpc_path: + audio_session = QAICInferenceSession(self.audio_model.qpc_path, device_ids) batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) @@ -1837,18 +2810,25 @@ def kv_offload_generate( if k in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} } + audio_inputs = {} + if self.audio_model is not None and self.audio_model.qpc_path: + audio_inputs = {k: v for k, v in inputs.items() if k in set(audio_session.input_names)} vision_inputs_fp16 = {"pixel_values", "image_masks"} vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) - + audio_inputs_fp16 = {"input_features"} + audio_inputs.update({k: audio_inputs[k].astype("float16") for k in audio_inputs_fp16 if k in audio_inputs}) vision_start = perf_counter() vision_outputs = {} if vision_inputs: vision_outputs = vision_session.run(vision_inputs) + audio_outputs = {} + if audio_inputs: + audio_outputs = audio_session.run(audio_inputs) vision_end = perf_counter() - lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs and k not in audio_inputs} if "position_ids" in inputs: lang_inputs["position_ids"] = inputs["position_ids"] lang_inputs.pop("attention_mask") @@ -1862,9 +2842,11 @@ def kv_offload_generate( lang_inputs["image_idx"] = np.array([[0]]) if self.vision_model.qpc_path: vision_session.deactivate() + if self.audio_model is not None and self.audio_model.qpc_path: + 
audio_session.deactivate() lang_session.activate() - lang_session.set_buffers(vision_outputs) + lang_session.set_buffers({**vision_outputs, **audio_outputs}) if self.comp_ctx_lengths_prefill is not None: list_of_comp_ctx_lengths_prefill = [ @@ -1890,6 +2872,7 @@ def kv_offload_generate( ] outputs = lang_session.run(chunk_inputs) chunk_inputs["image_idx"] = outputs["image_idx_output"] + chunk_inputs["audio_idx"] = outputs["audio_idx_output"] if self._write_io_dir is not None: write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) @@ -1904,7 +2887,7 @@ def kv_offload_generate( ] ) if not_mllama: - lang_session.skip_buffers(vision_outputs.keys()) + lang_session.skip_buffers(set(vision_outputs.keys()) | set(audio_outputs.keys())) # Get first token lang_inputs["input_ids"] = outputs["logits"].argmax(2) lang_inputs["position_ids"] = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 @@ -2709,6 +3692,72 @@ def from_pretrained( ) +class QEFFAutoModelForMultimodalLM: + """ + QEfficient class for multimodal language models from the HuggingFace hub. + + This class supports both single and dual QPC (Quantized Package Compilation) + approaches and mirrors the image-text-to-text multimodal workflow. + """ + + _hf_auto_class = AutoModelForMultimodalLM + + def __new__( + self, + model: nn.Module, + kv_offload: Optional[bool] = True, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): + """ + Instantiate the appropriate internal class for single or dual QPC mode. + """ + if kv_offload: + return _QEFFAutoModelForMultimodalLMMultiQPC(model, continuous_batching, qaic_config=qaic_config, **kwargs) + else: + return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs) + + @classmethod + @with_replaced_quantizers + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): + """ + Load a QEfficient multimodal LM from a pretrained HuggingFace model or local path. + """ + enable_proxy = kwargs.pop("enable_proxy", False) + + # TODO: add a check to see if kv_offload is allowed for given model by loading the config and checking architecture or type of config here. 
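+        # `kv_offload` picks the execution layout: True builds separate QPCs
+        # for the vision/audio encoders and the language decoder
+        # (_QEFFAutoModelForMultimodalLMMultiQPC), while False falls back to
+        # the single-QPC image-text-to-text path, which has no audio tower.
+        # Continuous batching therefore only makes sense with kv_offload=True,
+        # which is what the check below enforces. Illustrative usage:
+        #   model = QEFFAutoModelForMultimodalLM.from_pretrained("<hf-model-id>", kv_offload=True)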
+        if continuous_batching and not kv_offload:
+            raise NotImplementedError("Continuous batching is not supported for kv_offload = False")
+
+        if kwargs.get("attn_implementation", None) not in {None, "eager"}:
+            logger.warning('Updating attn_implementation="eager"')
+
+        if kwargs.get("low_cpu_mem_usage", None):
+            logger.warning("Updating low_cpu_mem_usage=False")
+
+        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
+
+        return cls(
+            model,
+            kv_offload=kv_offload,
+            continuous_batching=continuous_batching,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            qaic_config=qaic_config,
+            **kwargs,
+        )
+
+
 MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {
     "InternVLChatModel": QEFFAutoModelForImageTextToText,
     "MolmoForCausalLM": QEFFAutoModelForImageTextToText,
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index a5e16489f..2fc9d8625 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -175,7 +175,7 @@
     Qwen2_5_VLVisionAttention,
 )
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2RMSNorm as Qwen2_5RMSNorm,
+    Qwen2_5_VLRMSNorm as Qwen2_5RMSNorm,
 )
 from transformers.models.qwen3.modeling_qwen3 import (
     Qwen3Attention,
@@ -193,6 +193,23 @@
     Qwen3MoeRotaryEmbedding,
     Qwen3MoeSparseMoeBlock,
 )
+from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
+    Qwen3OmniMoeAudioAttention,
+    Qwen3OmniMoeAudioEncoder,
+    Qwen3OmniMoeAudioEncoderLayer,
+    Qwen3OmniMoeForConditionalGeneration,
+    Qwen3OmniMoeTextRMSNorm,
+    Qwen3OmniMoeThinkerForConditionalGeneration,
+    Qwen3OmniMoeThinkerTextAttention,
+    Qwen3OmniMoeThinkerTextDecoderLayer,
+    Qwen3OmniMoeThinkerTextModel,
+    Qwen3OmniMoeThinkerTextRMSNorm,
+    Qwen3OmniMoeThinkerTextRotaryEmbedding,
+    Qwen3OmniMoeThinkerTextSparseMoeBlock,
+    Qwen3OmniMoeThinkerTextTopKRouter,
+    Qwen3OmniMoeVisionAttention,
+    Qwen3OmniMoeVisionEncoder,
+)
 from transformers.models.qwen3_vl.modeling_qwen3_vl import (
     Qwen3VLForConditionalGeneration,
     Qwen3VLModel,
@@ -455,6 +472,21 @@
     QEffQwen3MoeRotaryEmbedding,
     QEffQwen3MoeSparseMoeBlock,
 )
+from QEfficient.transformers.models.qwen3_omni.modeling_qwen3_omni import (
+    QEffQwen3OmniMoeAudioAttention,
+    QEffQwen3OmniMoeAudioEncoder,
+    QEffQwen3OmniMoeAudioEncoderLayer,
+    QEffQwen3OmniMoeForConditionalGeneration,
+    QEffQwen3OmniMoeThinkerForConditionalGeneration,
+    QEffQwen3OmniMoeThinkerTextAttention,
+    QEffQwen3OmniMoeThinkerTextDecoderLayer,
+    QEffQwen3OmniMoeThinkerTextModel,
+    QEffQwen3OmniMoeThinkerTextRotaryEmbedding,
+    QEffQwen3OmniMoeThinkerTextSparseMoeBlock,
+    QEffQwen3OmniMoeThinkerTextTopKRouter,
+    QEffQwen3OmniMoeVisionAttention,
+    QEffQwen3OmniMoeVisionEncoder,
+)
 from QEfficient.transformers.models.qwen3_vl.modeling_qwen3_vl import (
     QEffQwen3VLForConditionalGeneration,
     QEffQwen3VLModel,
@@ -526,6 +558,8 @@ class CustomOpsTransform(ModuleMappingTransform):
     Olmo2RMSNorm: CustomRMSNormAIC,
     Qwen3VLMoeTextRMSNorm: CustomRMSNormAIC,
     Qwen3VLTextRMSNorm: CustomRMSNormAIC,
+    Qwen3OmniMoeTextRMSNorm: CustomRMSNormAIC,
+    Qwen3OmniMoeThinkerTextRMSNorm: CustomRMSNormAIC,
 }
@@ -694,6 +728,20 @@ class KVCacheTransform(ModuleMappingTransform):
     Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel,
     Qwen2_5_VLVisionAttention:
QEffQwen2_5_VLVisionAttention, Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel, + # Qwen3Omni + Qwen3OmniMoeAudioEncoder: QEffQwen3OmniMoeAudioEncoder, + Qwen3OmniMoeAudioEncoderLayer: QEffQwen3OmniMoeAudioEncoderLayer, + Qwen3OmniMoeAudioAttention: QEffQwen3OmniMoeAudioAttention, + Qwen3OmniMoeForConditionalGeneration: QEffQwen3OmniMoeForConditionalGeneration, + Qwen3OmniMoeThinkerForConditionalGeneration: QEffQwen3OmniMoeThinkerForConditionalGeneration, + Qwen3OmniMoeThinkerTextSparseMoeBlock: QEffQwen3OmniMoeThinkerTextSparseMoeBlock, + Qwen3OmniMoeThinkerTextModel: QEffQwen3OmniMoeThinkerTextModel, + Qwen3OmniMoeThinkerTextDecoderLayer: QEffQwen3OmniMoeThinkerTextDecoderLayer, + Qwen3OmniMoeThinkerTextAttention: QEffQwen3OmniMoeThinkerTextAttention, + Qwen3OmniMoeVisionAttention: QEffQwen3OmniMoeVisionAttention, + Qwen3OmniMoeVisionEncoder: QEffQwen3OmniMoeVisionEncoder, + Qwen3OmniMoeThinkerTextRotaryEmbedding: QEffQwen3OmniMoeThinkerTextRotaryEmbedding, + Qwen3OmniMoeThinkerTextTopKRouter: QEffQwen3OmniMoeThinkerTextTopKRouter, # Starcoder2 Starcoder2Attention: QEffStarcoder2Attention, Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, diff --git a/QEfficient/transformers/models/qwen3_omni/__init__.py b/QEfficient/transformers/models/qwen3_omni/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/transformers/models/qwen3_omni/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/qwen3_omni/modeling_qwen3_omni.py b/QEfficient/transformers/models/qwen3_omni/modeling_qwen3_omni.py new file mode 100644 index 000000000..0c3907ea2 --- /dev/null +++ b/QEfficient/transformers/models/qwen3_omni/modeling_qwen3_omni.py @@ -0,0 +1,1633 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import Qwen3OmniMoeForConditionalGeneration
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    BaseModelOutputWithPooling,
+)
+from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
+    Qwen3OmniMoeAudioAttention,
+    Qwen3OmniMoeAudioEncoder,
+    Qwen3OmniMoeAudioEncoderLayer,
+    Qwen3OmniMoeTextConfig,
+    Qwen3OmniMoeThinkerCausalLMOutputWithPast,
+    Qwen3OmniMoeThinkerForConditionalGeneration,
+    Qwen3OmniMoeThinkerTextAttention,
+    Qwen3OmniMoeThinkerTextDecoderLayer,
+    Qwen3OmniMoeThinkerTextModel,
+    Qwen3OmniMoeThinkerTextRotaryEmbedding,
+    Qwen3OmniMoeThinkerTextSparseMoeBlock,
+    Qwen3OmniMoeThinkerTextTopKRouter,
+    Qwen3OmniMoeVisionAttention,
+    Qwen3OmniMoeVisionBlock,
+    Qwen3OmniMoeVisionEncoder,
+    _get_feat_extract_output_lengths,
+    apply_rotary_pos_emb_vision,
+    repeat_kv,
+    rotate_half,
+)
+
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils import constants
+from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
+from QEfficient.utils.logging_utils import logger
+
+
+def qeff_apply_interleaved_mrope(freqs, mrope_section):
+    """Apply interleaved MRoPE to 3D rotary embeddings.
+    Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+    interleaved [THWTHWTHW...TT], preserving frequency continuity.
+    Args:
+        freqs: (3, bs, seq_len, head_dim // 2)
+        mrope_section: (3,)
+    Returns:
+        freqs_t: (bs, seq_len, head_dim // 2)
+    """
+    freqs_t = freqs[0]  # T frequencies; selected H and W slots are overwritten in place below
+    half_shape = freqs.shape[-1] // 2
+    for dim, offset in enumerate((1, 2), start=1):  # H, W
+        length = mrope_section[dim] * 3
+        idx = slice(offset, length, 3)
+        freqs_t[..., idx] = freqs[dim, ..., idx]
+        offset += half_shape
+        length += half_shape
+        idx = slice(offset, length, 3)
+        freqs_t[..., idx] = freqs[dim, ..., idx]
+    return freqs_t
+
+
+def qeff_pad_sequence(
+    tensor_list: List[torch.Tensor], batch_first: bool = False, padding_value: float = 0.0
+) -> torch.Tensor:
+    """ONNX-export friendly replacement for nn.utils.rnn.pad_sequence."""
+    if len(tensor_list) == 0:
+        raise ValueError("qeff_pad_sequence expects a non-empty tensor list")
+
+    max_len = max(t.shape[0] for t in tensor_list)
+    padded = []
+    for tensor in tensor_list:
+        pad_len = max_len - tensor.shape[0]
+        if pad_len > 0:
+            pad_shape = (pad_len,) + tuple(tensor.shape[1:])
+            pad_tensor = tensor.new_zeros(pad_shape)
+            if padding_value:
+                pad_tensor.fill_(padding_value)
+            tensor = torch.cat([tensor, pad_tensor], dim=0)
+        padded.append(tensor)
+
+    output = torch.stack(padded, dim=0)
+    if not batch_first:
+        output = output.transpose(0, 1)
+    return output
+
+
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+    Explanation:
+        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+        the vision embedding part, we apply rotary position embedding on the temporal, height and width dimensions
+        separately. Here we split the channel dimension into 3 chunks for the temporal, height and width rotary
+        position embeddings. For the text embedding part, we just apply 1D rotary position embedding. The three
+        rotary position indices (temporal, height and width) of a text token are always the same, so its rotary
+        position embedding is no different from that in modern LLMs.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offset position ids when working with a KV-cache.
+        mrope_section(`List(int)`):
+            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids]
+    sin = sin[position_ids]
+    cos = qeff_apply_interleaved_mrope(cos, mrope_section)
+    sin = qeff_apply_interleaved_mrope(sin, mrope_section)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+class QEffQwen3OmniMoeThinkerTextRotaryEmbedding(Qwen3OmniMoeThinkerTextRotaryEmbedding):
+    """
+    Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+    The only difference is:
+    - Static sin/cos computations are added.
+    """
+
+    def __init__(self, config: Qwen3OmniMoeTextConfig, device=None):
+        super().__init__(config=config)
+        # Build here to make `torch.jit.trace` work.
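+        # Building the full cos/sin tables once (up to original_max_seq_len)
+        # lets the export path treat the rotary embeddings as fixed buffers:
+        # the ONNX graph gathers rows by position id instead of recomputing
+        # sin/cos at runtime. E.g. with head_dim=128 and a 32K context this
+        # registers two (32768, 128) constant buffers (sizes are illustrative,
+        # not fixed by this diff).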
+ self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class QEffQwen3OmniMoeVisionEncoder(Qwen3OmniMoeVisionEncoder): + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + bs, num_frames, height, width = grid_thw.shape + max_hw = max(height, width) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = bs * num_frames * height * width + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + coords = coords.repeat(bs * num_frames, 1) + pos_ids[:] = coords + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + # + bs, t, h, w = grid_thw.shape + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + max_t = torch.tensor(self.num_grid_per_side - 1, device=h_idxs.device) + + h_idxs_ceil = torch.minimum(h_idxs_floor + 1, max_t) # working + w_idxs_ceil = torch.minimum(w_idxs_floor + 1, max_t) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + idx_tensor = torch.stack(indices, dim=0).to(dtype=torch.long, device=self.pos_embed.weight.device) # [4, h*w] + + weight_tensor = torch.stack(weights, dim=0).to( + dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w]) + + 
patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + pos_embed = patch_pos_embeds[0] + pos_embed = pos_embed.repeat(t, 1) + + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + x_expanded = patch_pos_embeds.unsqueeze(0) + x_expanded = x_expanded.expand(bs, -1, -1) + patch_pos_embeds = x_expanded.reshape(-1, patch_pos_embeds.size(1)) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + + # return hidden_states, pos_embeds + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + bs, t, h, w = grid_thw.shape + + t = torch.arange(t, t + 1).squeeze().expand(bs) + h = torch.arange(h, h + 1).squeeze().expand(bs) + w = torch.arange(w, w + 1).squeeze().expand(bs) + + cu_seqlens = (h * w).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), cu_seqlens]) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + return hidden_states, deepstack_feature_lists + + +class QEffQwen3OmniMoeAudioAttention(Qwen3OmniMoeAudioAttention): + """ + Copied from Qwen3OmniMoeAudioAttention with QEff wrapper class so audio attention + can be transformed the same way as other QEff module-level overrides. + """ + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + if attention_mask is None and cu_seqlens is not None: + # Whisper-style eager mask: True means masked, False means keep. 
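+            # E.g. for cu_seqlens = [0, 2, 5] and seq_length = 5 this builds
+            # (T = masked, F = keep):
+            #
+            #   F F T T T
+            #   F F T T T
+            #   T T F F F
+            #   T T F F F
+            #   T T F F F
+            #
+            # so every audio token attends only within its own chunk.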
+ attention_mask = torch.ones( + (1, 1, seq_length, seq_length), + dtype=torch.bool, + device=hidden_states.device, + ) + for i in range(1, cu_seqlens.shape[0]): + start = int(cu_seqlens[i - 1].item()) + end = int(cu_seqlens[i].item()) + attention_mask[..., start:end, start:end] = False + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + cu_seq_lens_q=cu_seqlens, # kept for signature compatibility + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.out_proj(attn_output) + return attn_output + + +class QEffQwen3OmniMoeAudioEncoderLayer(Qwen3OmniMoeAudioEncoderLayer): + """ + QEff wrapper for audio encoder layer. + Replaces attention submodule with QEff audio attention while preserving weights/behavior. + """ + + def __init__(self, config): + super().__init__(config) + old_attn_state = self.self_attn.state_dict() + self.self_attn = QEffQwen3OmniMoeAudioAttention(config) + self.self_attn.load_state_dict(old_attn_state, strict=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + clamp_value = 65504.0 - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + return outputs + + +class QEffQwen3OmniMoeAudioEncoder(Qwen3OmniMoeAudioEncoder): + def __init__(self, config): + super().__init__(config) + old_layers_state = self.layers.state_dict() + self.layers = nn.ModuleList([QEffQwen3OmniMoeAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layers.load_state_dict(old_layers_state, strict=True) + + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + **kwargs, + ): + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = torch.cumsum(chunk_num, dim=0) - 1 + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = qeff_pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = qeff_pad_sequence( + [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], + batch_first=True, + padding_value=False, + ) + padded_feature = padded_feature.unsqueeze(1) + + padded_embeds = 
[] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) + + positional_embedding = ( + self.positional_embedding.positional_embedding[: padded_embed.shape[1], :] + .unsqueeze(0) + .to(padded_embed.dtype) + ) + padded_embed = padded_embed + positional_embedding + # Keep only valid timesteps per chunk (equivalent to padded_embed[padded_mask_after_cnn]) + # while avoiding boolean indexing that exports through NonZero. + valid_hidden_states = [] + for i in range(padded_embed.shape[0]): + seq_len = int(feature_lens_after_cnn[i].item()) + valid_hidden_states.append(padded_embed[i, :seq_len, :]) + hidden_states = torch.cat(valid_hidden_states, dim=0) + cu_chunk_lens = [0] + window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + hidden_states = layer_outputs[0] + + hidden_states = self.ln_post(hidden_states) + hidden_states = self.proj1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.proj2(hidden_states) + return BaseModelOutputWithPooling(last_hidden_state=hidden_states) + + +class QEffQwen3OmniMoeVisionAttention(Qwen3OmniMoeVisionAttention): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos() + sin = emb.sin() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + + # Create index grids + seq_len = attention_mask.shape[-1] + rows = torch.arange(seq_len).view(1, -1) + cols = torch.arange(seq_len).view(-1, 1) + + # Prepare start and end indices + start = cu_seqlens[:-1].view(-1, 1, 1) + end = cu_seqlens[1:].view(-1, 1, 1) + + # Create block masks using broadcasting + row_mask = (rows >= start) & (rows < end) + col_mask = (cols >= start) & (cols < end) + block_mask = row_mask & col_mask # shape: (num_blocks, seq_len, seq_len) + + # Combine all blocks into one mask + final_mask = torch.ones((seq_len, seq_len), dtype=torch.float32) + final_mask[block_mask.any(dim=0)] = 0 + + final_mask = torch.where(final_mask == 1.0, torch.finfo(q.dtype).min, final_mask) + + attention_mask[0] = final_mask + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffQwen3OmniMoeThinkerTextAttention(Qwen3OmniMoeThinkerTextAttention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + sin_cached=None, + cos_cached=None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + # + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, + key_states, + cos_cached, + sin_cached, + position_ids[1:], + self.config.rope_scaling["mrope_section"], + ) + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin_cached, + "cos": cos_cached, + "batch_index": batch_index, + "position_ids": position_ids[0], + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_values=past_key_values, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_values + + +class QEffQwen3OmniMoeThinkerTextDecoderLayer(Qwen3OmniMoeThinkerTextDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + comp_ctx_lengths=comp_ctx_lengths, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + return outputs + + +class QEffQwen3OmniMoeThinkerTextModel(Qwen3OmniMoeThinkerTextModel): + def __qeff_init__(self): + self.rotary_emb = QEffQwen3OmniMoeThinkerTextRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + visual_pos_masks: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if self.config.use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + 
inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids[0], target_length=target_length, sliding_window=None + ) + + hidden_states = inputs_embeds + # + position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + layer_idx = 0 + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if deepstack_visual_embeds is not None and layer_idx in range(deepstack_visual_embeds.shape[0]): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + layer_idx += 1 + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return (hidden_states, past_key_values) + + def _deepstack_process( + self, + hidden_states: torch.Tensor, + visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor, + ): + visual_pos_masks = visual_pos_masks.unsqueeze(-1).expand(-1, -1, self.config.hidden_size) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states = hidden_states.clone() + mixed_embeds = hidden_states + visual_embeds + + local_this = torch.where(visual_pos_masks, mixed_embeds, hidden_states) + + return local_this + + +class QEffQwen3OmniMoeThinkerTextTopKRouter(Qwen3OmniMoeThinkerTextTopKRouter): + def forward(self, hidden_states): + + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + # router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) + top_w, top_i = torch.topk(prob, self.top_k, dim=-1) + top_w = top_w / torch.einsum("bi->b", top_w)[:, None] + top_w = top_w.to(hidden_states.dtype) + + return top_w, top_i, prob + + +class QEffPrefillChunkedQwen3OmniMoeThinkerTextSparseMoeBlock(Qwen3OmniMoeThinkerTextSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + act = getattr(self.experts, "act_fn", F.silu) + + # router_logits = self.gate(x) # [T, E] + # prob = 
F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype) + # top_w, top_i = torch.topk(prob, self.top_k, dim=-1) + # top_w = top_w / torch.einsum("bi->b", top_w)[:, None] + # top_w = top_w.to(hidden_states.dtype) + + top_w, top_i, prob = self.gate(x) + routing_weights = torch.zeros((T, self.gate.num_experts), dtype=x.dtype) + routing_weights.scatter_(1, top_i, top_w) + + expert_out = torch.zeros_like(x, dtype=x.dtype) + + for e in range(self.gate.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) + + W_gate_up_e = self.experts.gate_up_proj[e] # [H, 2I] + W_dn_e = self.experts.down_proj[e] # [I, H] + # + gate_up = x @ W_gate_up_e.T # [T, 2I] + + I2 = gate_up.shape[-1] // 2 + gate = gate_up[:, :I2] # [T, I] + up = gate_up[:, I2:] # [T, I] + intermediate = up * act(gate) + down = intermediate @ W_dn_e.T + masked_down = torch.where( + routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out, dtype=down.dtype) + ) # TODO: verify and remove + expert_out += masked_down + expert_out = expert_out.to(x.dtype).view(B, S, H) + return expert_out + + +class QEffQwen3OmniMoeThinkerForConditionalGeneration(Qwen3OmniMoeThinkerForConditionalGeneration): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + output = Qwen3OmniMoeThinkerCausalLMOutputWithPast( + # last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.last_hidden_state, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + return output if return_dict else output.to_tuple() + + +class QEffQwen3OmniMoeThinkerTextSparseMoeBlock(Qwen3OmniMoeThinkerTextSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + B, S, H = hidden_states.shape + T = B * S + x = hidden_states.view(T, H) + + # router_logits = self.gate(x) + # prob = F.softmax(router_logits, dim=-1, dtype=torch.float) + # top_w, top_i = torch.topk(prob, self.top_k, dim=-1) + # top_w = top_w / torch.einsum("bi->b", top_w)[:, None] + # top_w = top_w.to(x.dtype) + top_w, top_i, prob = self.gate(x) + idx = top_i.reshape(-1) + 
w_up = self.experts.gate_up_proj.index_select(0, idx) + w_dn = self.experts.down_proj.index_select(0, idx) + + xk = x.unsqueeze(1).expand(-1, self.gate.top_k, -1).contiguous() + xk = xk.view(-1, 1, H) + gate_up = torch.bmm(xk, w_up.transpose(1, 2)) + I2 = gate_up.size(-1) + half = I2 // 2 + gate, up = gate_up[..., :half], gate_up[..., half:] + intermediate = up * self.experts.act_fn(gate) + experts_out = torch.bmm(intermediate, w_dn.transpose(1, 2)) + experts_out = experts_out.view(T, self.gate.top_k, H) * top_w.unsqueeze(-1) + experts_out = torch.einsum("bnd->bd", experts_out) + return experts_out.view(B, S, H) + + +class QEffQwen3OmniMoeVisionEncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.thinker.visual + + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + # import ipdb; ipdb.set_trace() + return {Qwen3OmniMoeVisionBlock} + + def forward(self, pixel_values, image_grid_thw): + image_embeds, deepstack_feature_lists = self.model.thinker.visual(pixel_values, grid_thw=image_grid_thw) + bs = image_grid_thw.shape[0] + split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) + image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1)) + deepstack_features = torch.stack( + [feature.reshape(bs, split_size, feature.size(1)) for feature in deepstack_feature_lists], + dim=0, # new axis for "features" + ) + + return image_embeds, deepstack_features + + +class QEffQwen3OmniMoeAudioEncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.model.vision_model = self.model.thinker.visual + + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. + """ + # import ipdb; ipdb.set_trace() + return {QEffQwen3OmniMoeAudioEncoder} + + def forward(self, input_features, feature_attention_mask): + + batch_size = input_features.shape[0] + feature_lens = torch.full( + (batch_size,), input_features.shape[-1], dtype=torch.long, device=input_features.device + ) + input_features = input_features.permute(0, 2, 1).reshape(-1, input_features.shape[1]).permute(1, 0) + audio_outputs = self.model.thinker.audio_tower( + input_features, + feature_lens=feature_lens, + return_dict=True, + ) + audio_embeds = audio_outputs.last_hidden_state + audio_split_size = torch.floor_divide(torch.tensor(audio_embeds.size(0)), batch_size) + audio_embeds = audio_embeds.reshape(batch_size, audio_split_size, audio_embeds.size(1)) + + return audio_embeds + + +class QEffQwen3OmniMoeDecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + # self.language_model = self.model.thinker + + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. 
+ """ + return {QEffQwen3OmniMoeThinkerTextDecoderLayer} + + def forward( + self, + input_ids, + vision_embeds, + deepstack_features, + audio_embeds, + position_ids, + image_idx, + audio_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + comp_ctx_lengths: Optional[List[int]] = None, + ): + inputs_embeds = self.model.thinker.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + selected = input_ids == self.model.thinker.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + + num_features, bs, split_size, C = deepstack_features.shape + x = deepstack_features.reshape(num_features, bs * split_size, C) + deepstack_features_expanded = x[:, indices1, :] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + inputs_embeds = image_input_embeds + + selected_audio = input_ids == self.model.thinker.config.audio_token_id + audio_indices1 = selected_audio.to(torch.int64).cumsum(1) - 1 + audio_indices1 = torch.where(audio_indices1 != -1, audio_indices1 + audio_idx, audio_indices1) + audio_indices0 = torch.arange(selected_audio.shape[0], device=input_ids.device).view(-1, 1) + audio_features_expanded = audio_embeds.reshape(-1, C).unsqueeze(0)[audio_indices0, audio_indices1] + audio_input_embeds = torch.where(selected_audio.unsqueeze(-1), audio_features_expanded, inputs_embeds) + inputs_embeds = audio_input_embeds + + image_mask = selected.clone() + + visual_pos_masks = None + deepstack_visual_embeds = None + + if image_mask is not None: + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_features_expanded + + outputs = self.model.thinker( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + # comp_ctx_lengths=comp_ctx_lengths, + # batch_index=batch_index, + use_cache=True, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + + hidden_states = outputs.hidden_states[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.model.thinker.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + audio_idx = (audio_indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return logits, vision_embeds, deepstack_features, audio_embeds, image_idx, audio_idx, outputs.past_key_values + + +class QEffQwen3OmniMoeForConditionalGeneration(Qwen3OmniMoeForConditionalGeneration): + # def __qeff_init__(self, model): + # super().__init__() + # self.model = model + # self.language_model = self.thinker + + def get_submodules_for_export(self) -> Type[nn.Module]: + """ + Return the set of class used as the repeated layer across the model for subfunction extraction. + Notes: + This method should return the *class object* (not an instance). + Downstream code can use this to find/build subfunctions for repeated blocks. 
+ """ + return {QEffQwen3OmniMoeThinkerTextDecoderLayer} + + def get_qeff_vision_encoder(self): + return QEffQwen3OmniMoeVisionEncoderWrapper(self) + + def get_qeff_audio_encoder(self): + return QEffQwen3OmniMoeAudioEncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEffQwen3OmniMoeDecoderWrapper(self) + + def forward( + self, + input_ids, + position_ids, + # input_features, + pixel_values, + image_grid_thw, + # feature_attention_mask, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): + + image_embeds, deepstack_feature_lists = self.thinker.visual(pixel_values, grid_thw=image_grid_thw) + bs = image_grid_thw.shape[0] + split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs) + image_embeds = image_embeds.reshape(bs, split_size, image_embeds.size(1)) + deepstack_features = torch.stack( + [feature.reshape(bs, split_size, feature.size(1)) for feature in deepstack_feature_lists], + dim=0, # new axis for "features" + ) + inputs_embeds = self.thinker.get_input_embeddings()(input_ids) + B, N, C = inputs_embeds.shape + selected = input_ids == self.thinker.config.image_token_id + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = image_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1] + + num_features, bs, split_size, C = deepstack_features.shape + x = deepstack_features.reshape(num_features, bs * split_size, C) + deepstack_features_expanded = x[:, indices1, :] + image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) + # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) + inputs_embeds = image_input_embeds + + image_mask = selected.clone() + + visual_pos_masks = None + deepstack_visual_embeds = None + + if image_mask is not None: + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_features_expanded + + outputs = self.thinker( + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + # comp_ctx_lengths=comp_ctx_lengths, + # batch_index=batch_index, + use_cache=True, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) + + hidden_states = outputs.hidden_states[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index] + logits = self.thinker.lm_head(hidden_states) + image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + return logits, pixel_values, image_idx, outputs.past_key_values + + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): + inputs_shapes = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + # vision_size = 1024 + vision_size = 187 + audio_size = int(_get_feat_extract_output_lengths(torch.tensor([290])).item()) + inputs_shapes["vision_embeds"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + vision_size, + self.thinker.config.vision_config.out_hidden_size, + ) + inputs_shapes["audio_embeds"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + audio_size, + self.thinker.config.vision_config.out_hidden_size, + ) + inputs_shapes["image_grid_thw"] = (1, 1, 22, 34) + inputs_shapes["position_ids"] = ( + 3, + 
+        inputs_shapes["pixel_values"] = (748, 1536)
+        inputs_shapes["image_idx"] = (1, 1)
+        inputs_shapes["audio_idx"] = (1, 1)
+        inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2)
+        inputs_shapes["deepstack_features"] = (
+            len(self.thinker.config.vision_config.deepstack_visual_indexes),
+            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            vision_size,
+            self.thinker.config.vision_config.out_hidden_size,
+        )
+
+        inputs_shapes["input_features"] = (1, 128, 290)
+        inputs_shapes["feature_attention_mask"] = (1, 290)
+        vision_inputs = {}
+        audio_inputs = {}
+        lang_inputs = {}
+        vision_inputs["pixel_values"] = torch.zeros(inputs_shapes["pixel_values"], dtype=torch.float32)
+        vision_inputs["image_grid_thw"] = torch.zeros(inputs_shapes["image_grid_thw"], dtype=torch.int64)
+        audio_inputs["input_features"] = torch.zeros(inputs_shapes["input_features"], dtype=torch.float32)
+        audio_inputs["feature_attention_mask"] = torch.ones(
+            inputs_shapes["feature_attention_mask"], dtype=torch.int64
+        )
+
+        lang_inputs["input_ids"] = torch.zeros(inputs_shapes["input_ids"], dtype=torch.int64)
+        lang_inputs["vision_embeds"] = torch.zeros(inputs_shapes["vision_embeds"], dtype=torch.float32)
+        lang_inputs["position_ids"] = (
+            (
+                torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
+                .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
+                .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
+            )
+            .unsqueeze(0)
+            .repeat(4, 1, 1)
+        )
+        lang_inputs["image_idx"] = torch.zeros(inputs_shapes["image_idx"], dtype=torch.int64)
+        lang_inputs["audio_idx"] = torch.zeros(inputs_shapes["audio_idx"], dtype=torch.int64)
+        lang_inputs["deepstack_features"] = torch.zeros(inputs_shapes["deepstack_features"], dtype=torch.float32)
+        lang_inputs["audio_embeds"] = torch.zeros(inputs_shapes["audio_embeds"], dtype=torch.float32)
+
+        # Add data for KV
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+
+        kv_cache_shape = get_padding_shape_from_config(
+            config=self.thinker.config.text_config,
+            batch_size=fbs if continuous_batching else bs,
+            # seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
+            seq_len=512,
+        )
+
+        lang_inputs["past_key_values"] = [[] for _ in range(self.thinker.config.text_config.num_hidden_layers)]
+        for i in range(self.thinker.config.text_config.num_hidden_layers):
+            for kv in ["key", "value"]:
+                lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+
+        if continuous_batching:
+            lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
+
+        if comp_ctx_lengths is not None:
+            lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.int8)
+        inputs = {}
+        if kv_offload:
+            inputs["vision"] = vision_inputs
+            inputs["lang"] = lang_inputs
+            inputs["audio"] = audio_inputs
+        else:
+            lang_inputs.pop("vision_embeds")
+            lang_inputs.pop("deepstack_features")
+            lang_inputs.pop("audio_embeds")
+            inputs = {**vision_inputs, **lang_inputs}
+        return inputs
+
+    def get_specializations(
+        self,
+        batch_size: int,
+        prefill_seq_len: int,
+        ctx_len: int,
+        img_size: Optional[int] = None,
+        height: int = None,
+        width: int = None,
+        time: int = 1,
+        # dimensions: List = None,
+        num_frames: int = 1,
+        kv_offload: bool = False,
+        continuous_batching: bool = False,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
+        **compiler_options,
+    ):
+        comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None)
+        comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None)
+        if height is None or width is None:
+            height = constants.QWEN3_VL_HEIGHT
+            width = constants.QWEN3_VL_WIDTH
+            logger.warning(
+                f"Setting height and width to {height} and {width} respectively, as they were neither passed nor found in vision_config"
+            )
+        prefill_seq_len = prefill_seq_len if prefill_seq_len else 128
+        ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
+        channel = 3
+        patch_size = self.thinker.config.vision_config.patch_size
+        temporal_patch_size = self.thinker.config.vision_config.temporal_patch_size
+
+        IMAGE_FACTOR = 28
+        MIN_PIXELS = 4 * 28 * 28
+        MAX_PIXELS = 16384 * 28 * 28
+        MAX_RATIO = 200
+
+        def round_by_factor(number: int, factor: int) -> int:
+            """Returns the closest integer to 'number' that is divisible by 'factor'."""
+            return round(number / factor) * factor
+
+        def ceil_by_factor(number: int, factor: int) -> int:
+            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+            return math.ceil(number / factor) * factor
+
+        def floor_by_factor(number: int, factor: int) -> int:
+            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+            return math.floor(number / factor) * factor
+
+        def smart_resize(
+            height: int,
+            width: int,
+            factor: int = IMAGE_FACTOR,
+            min_pixels: int = MIN_PIXELS,
+            max_pixels: int = MAX_PIXELS,
+        ) -> tuple[int, int]:
+            """
+            Rescales the image so that the following conditions are met:
+
+            1. Both dimensions (height and width) are divisible by 'factor'.
+
+            2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+            3. The aspect ratio of the image is maintained as closely as possible.
+            """
+            if max(height, width) / min(height, width) > MAX_RATIO:
+                raise ValueError(
+                    f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+                )
+            h_bar = max(factor, round_by_factor(height, factor))
+            w_bar = max(factor, round_by_factor(width, factor))
+            if h_bar * w_bar > max_pixels:
+                beta = math.sqrt((height * width) / max_pixels)
+                h_bar = floor_by_factor(height / beta, factor)
+                w_bar = floor_by_factor(width / beta, factor)
+            elif h_bar * w_bar < min_pixels:
+                beta = math.sqrt(min_pixels / (height * width))
+                h_bar = ceil_by_factor(height * beta, factor)
+                w_bar = ceil_by_factor(width * beta, factor)
+            return h_bar, w_bar
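+
+        # Worked example (illustrative, assuming patch_size=16 and
+        # temporal_patch_size=2 as in the default vision config): height=354,
+        # width=536 resize to 364x532, so grid_h=22 and grid_w=33 (34 after the
+        # +1 workaround below), grid_height=22*34=748,
+        # grid_width=16*16*2*3=1536, and vision_size=748//4=187, matching the
+        # dummy shapes in get_dummy_inputs.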
+ """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + # import ipdb; ipdb.set_trace() + resized_height, resized_width = smart_resize(height=height, width=width) + grid_h, grid_w = resized_height // patch_size, resized_width // patch_size + ##Added extra## + grid_w += 1 + grid_height = grid_h * grid_w + grid_width = patch_size * patch_size * temporal_patch_size * channel + vision_size = grid_height // 4 + vision_size = vision_size * num_frames * time + audio_size = int( + compiler_options.pop("audio_size", _get_feat_extract_output_lengths(torch.tensor([290])).item()) + ) + grid_height = grid_height * time * batch_size + + vision = [ + { + "batch_size": batch_size, + "vision_size": vision_size, + "grid_height": grid_height, + "grid_width": grid_width, + "time": time, + "grid_h": grid_h, + "grid_w": grid_w, + "num_feature_layers": len(self.thinker.config.vision_config.deepstack_visual_indexes), + } + ] + + audio = [{"batch_size": batch_size}] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "audio_size": audio_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + "audio_size": audio_size, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "vision_batch_size": batch_size, + "num_feature_layers": len(self.config.vision_config.deepstack_visual_indexes), + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "audio_size": audio_size, + "vision_batch_size": batch_size, + "grid_height": grid_height, + "grid_width": grid_width, + "grid_h": grid_h, + "grid_w": grid_w, + "time": time, + "num_feature_layers": len(self.thinker.config.vision_config.deepstack_visual_indexes), + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + 
+            lang_decode = {
+                "batch_size": full_batch_size if continuous_batching else batch_size,
+                "seq_len": 1,
+                "ctx_len": ctx_len,
+                "vision_size": vision_size,
+                "audio_size": audio_size,
+                "vision_batch_size": batch_size,
+                "grid_height": grid_height,
+                "grid_width": grid_width,
+                "time": time,
+                "grid_h": grid_h,
+                "grid_w": grid_w,
+                "num_feature_layers": len(self.thinker.config.vision_config.deepstack_visual_indexes),
+            }
+
+            if continuous_batching:
+                lang_decode["full_batch_size"] = kv_cache_batch_size
+            else:
+                lang_decode["batch_size"] = kv_cache_batch_size
+
+            lang = [lang_prefill, lang_decode]
+
+        specializations = {}
+
+        if kv_offload:
+            specializations["vision"] = vision
+            specializations["lang"] = lang
+            specializations["audio"] = audio
+            return specializations, compiler_options
+        else:
+            for spec in lang:
+                spec.pop("vision_size", None)
+                spec.pop("audio_size", None)
+            return lang, compiler_options
+
+    def get_onnx_dynamic_axes(
+        self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False
+    ):
+        # Define dynamic axes
+        num_layers = self.thinker.config.text_config.num_hidden_layers
+        vision_dynamic_axes = {
+            "pixel_values": {0: "grid_height", 1: "grid_width"},
+            "image_grid_thw": {0: "batch_size", 1: "time", 2: "grid_h", 3: "grid_w"},
+            "deepstack_features": {0: "num_feature_layers", 1: "batch_size", 2: "vision_size"},
+        }
+
+        lang_dynamic_axes = {
+            "input_ids": {0: "batch_size", 1: "seq_len"},
+            "position_ids": {1: "batch_size", 2: "seq_len"},
+            "vision_embeds": {0: "vision_batch_size", 1: "vision_size"},
+            "deepstack_features": {0: "num_feature_layers", 1: "vision_batch_size", 2: "vision_size"},
+            "audio_embeds": {0: "vision_batch_size", 1: "audio_size"},
+        }
+
+        audio_dynamic_axes = {
+            "input_features": {0: "batch_size"},
+            "feature_attention_mask": {0: "batch_size"},
+            "audio_embeds": {0: "vision_batch_size", 1: "audio_size"},
+        }
+
+        for i in range(num_layers):
+            lang_dynamic_axes[f"past_key.{i}"] = {
+                0: "full_batch_size" if continuous_batching else "batch_size",
+                2: "ctx_len",
+            }
+            lang_dynamic_axes[f"past_value.{i}"] = {
+                0: "full_batch_size" if continuous_batching else "batch_size",
+                2: "ctx_len",
+            }
+
+        if continuous_batching:
+            lang_dynamic_axes["batch_index"] = {0: "batch_size"}
+
+        if comp_ctx_lengths is not None:
+            lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
+
+        dynamic_axes = {}
+
+        if kv_offload:
+            dynamic_axes["vision"] = vision_dynamic_axes
+            dynamic_axes["lang"] = lang_dynamic_axes
+            dynamic_axes["audio"] = audio_dynamic_axes
+        else:
+            lang_dynamic_axes.pop("vision_embeds")
+            lang_dynamic_axes.pop("deepstack_features")
+            lang_dynamic_axes.pop("audio_embeds")
+            vision_dynamic_axes.pop("deepstack_features")
+            dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
+        return dynamic_axes
+
+    def get_output_names(self, kv_offload: bool = False):
+        vision_output_names = ["vision_embeds"]
+        vision_output_names.append("deepstack_features")
+        audio_output_names = ["audio_embeds"]
+        lang_output_names = ["logits"]
+        for i in range(self.thinker.config.text_config.num_hidden_layers):
+            for kv in ["key", "value"]:
+                lang_output_names.append(f"past_{kv}.{i}_RetainedState")
+
+        output_names = {}
+        if kv_offload:
+            lang_output_names.insert(1, "vision_embeds_RetainedState")
+            lang_output_names.insert(2, "deepstack_features_RetainedState")
+            lang_output_names.insert(3, "audio_embeds_RetainedState")
+            lang_output_names.insert(4, "image_idx_output")
+            lang_output_names.insert(5, "audio_idx_output")
+            output_names["vision"] = vision_output_names
+            output_names["lang"] = lang_output_names
+            output_names["audio"] = audio_output_names
+        else:
+            lang_output_names.insert(1, "pixel_values_RetainedState")
+            lang_output_names.insert(2, "image_idx_output")
+            return lang_output_names
+        return output_names
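+
+    # For reference (kv_offload=True): the language outputs are ordered
+    # ["logits", "vision_embeds_RetainedState", "deepstack_features_RetainedState",
+    # "audio_embeds_RetainedState", "image_idx_output", "audio_idx_output",
+    # "past_key.0_RetainedState", "past_value.0_RetainedState", ...], with one
+    # retained key/value pair per decoder layer.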
+
+    def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=1):
+        input_ids_length = inputs["input_ids"].shape[1]
+        inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1)
+        pos_ids, rope_deltas = self.thinker.get_rope_index(
+            inputs["input_ids"],
+            None if "image_grid_thw" not in inputs else inputs["image_grid_thw"],
+            None if "video_grid_thw" not in inputs else inputs["video_grid_thw"],
+            inputs["attention_mask"],
+            False,
+            None if "feature_attention_mask" not in inputs else torch.sum(inputs["feature_attention_mask"], dim=1),
+        )
+
+        inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0).to(torch.int64)
+
+        num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+
+        inputs["position_ids"] = F.pad(
+            inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
+        )
+
+        inputs.pop("image_grid_thw", None)
+        return inputs
+
+    def get_inputs_info(self):
+        return [
+            IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
+            IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
+            IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "image_size", "image_size")),
+        ]
diff --git a/QEfficient/transformers/quantizers/quantizer_awq.py b/QEfficient/transformers/quantizers/quantizer_awq.py
index b7199a71e..11abd8380 100644
--- a/QEfficient/transformers/quantizers/quantizer_awq.py
+++ b/QEfficient/transformers/quantizers/quantizer_awq.py
@@ -7,7 +7,7 @@
 import torch
 from transformers.quantizers.quantizer_awq import AwqQuantizer
-from transformers.utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion
+from transformers.utils.quantization_config import AwqBackend, AwqConfig, AwqFormat
 
 from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
 from QEfficient.transformers.quantizers.quantizer_utils import (
@@ -24,23 +24,20 @@ def post_init(self):
         """
         Safety checker that arguments are correct
         """
-        if self.backend not in [AwqBackendPackingMethod.AUTOAWQ]:
+        if self.backend not in [AwqBackend.LEGACY_AWQ]:
             raise ValueError(
-                f"Only quantization backend {AwqBackendPackingMethod.AUTOAWQ} is supported - not recognized backend {self.backend}"
+                f"Only quantization backend {AwqBackend.LEGACY_AWQ} is supported - not recognized backend {self.backend}"
             )
-        if isinstance(self.version, str):
-            self.version = AWQLinearVersion.from_str(self.version)
-        if self.version not in [AWQLinearVersion.GEMM]:
-            raise ValueError(
-                f"Only {AWQLinearVersion.GEMM} version in supported - not recognized version {self.version}"
-            )
+        self.version = self.version.lower()
+        if self.version not in [AwqFormat.GEMM]:
+            raise ValueError(f"Only {AwqFormat.GEMM} version is supported - not recognized version {self.version}")
 
-        do_fuse = getattr(self, "do_fuse", None)
-        fuse_max_seq_len = getattr(self, "fuse_max_seq_len", None)
-        if do_fuse or fuse_max_seq_len is not None:
+        if self.do_fuse or self.fuse_max_seq_len is not None:
             raise ValueError(
f"fused modules are not supported, got do_fuse={do_fuse}, fuse_max_seq_len={fuse_max_seq_len}" + f"fused modules are not supported, got do_fuse={self.do_fuse}, fuse_max_seq_len={self.fuse_max_seq_len}" ) if self.bits != 4: @@ -66,9 +63,6 @@ def update_torch_dtype(self, torch_dtype): logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") return None - def update_dtype(self, dtype): - return self.update_torch_dtype(dtype) - def _process_model_before_weight_loading(self, model, **kwargs): self.modules_to_not_convert = get_keys_to_not_convert(model) diff --git a/examples/Multimodal/qwen3_Omni/example.py b/examples/Multimodal/qwen3_Omni/example.py new file mode 100644 index 000000000..ee32ff188 --- /dev/null +++ b/examples/Multimodal/qwen3_Omni/example.py @@ -0,0 +1,153 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import transformers +from PIL import Image +from qwen_omni_utils import process_mm_info + +# from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForMultimodalLM + +model_id = "Qwen/Qwen3-Omni-30B-A3B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +# config.talker_config.text_config.num_hidden_layers = 2 +# config.thinker_config.text_config.num_hidden_layers = 2 +# config.thinker_config.vision_config.deepstack_visual_indexes = [8] +# config.thinker_config.vision_config.depth = 9 + +config.enable_audio_output = False +config.torch_dtype = "float32" +qeff_model = QEFFAutoModelForMultimodalLM.from_pretrained( + model_id, attn_implementation="eager", kv_offload=True, config=config +) +# import ipdb; ipdb.set_trace() +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) +### use skip_vision=Ture, if want to run only text, or false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + ## Set Batch_Size ## + batch_size = 1 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + # mdts_mos=1, + # use_onnx_subfunctions=True, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=4096, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + # mdts_mos=1, + # use_onnx_subfunctions=True, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image 
+
+    # Image-only chat variant (kept for reference; the image+audio `conversation`
+    # below is what this example actually runs).
+    messages_1 = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": "Describe all the colors seen in the image."},
+            ],
+        },
+    ]
+
+    messages = [messages_1] * batch_size
+
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "audio", "audio": "path/to/cough.wav"},  # replace with a local audio file
+                {"type": "text", "text": "What can you see and hear? Answer in one short sentence."},
+            ],
+        },
+    ]
+
+    # Set whether to use audio in video
+    USE_AUDIO_IN_VIDEO = False
+
+    # Preparation for inference
+    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
+    audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
+    inputs = processor(
+        text=text,
+        audio=audios,
+        images=images,
+        videos=videos,
+        return_tensors="pt",
+        padding=True,
+        use_audio_in_video=USE_AUDIO_IN_VIDEO,
+    )
+    # inputs = inputs.to(qeff_model.device).to(qeff_model.dtype)
+    inputs = inputs.to("cpu")
+
+    inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size)
+    streamer = TextStreamer(tokenizer)
+    output = qeff_model.generate(inputs=inputs, generation_len=100)
+    print(output.generated_ids)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
diff --git a/pyproject.toml b/pyproject.toml
index 003143d8f..79fe18204 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
     "ftfy==6.3.1",
     "imageio==2.37.2",
     "imageio-ffmpeg==0.6.0",
+    "qwen-omni-utils==0.0.9",
     "torch==2.7.0; platform_machine=='aarch64'",
     # Specifying torch cpu package URL per python version, update the list once pytorch releases whl for python>3.11
     "torch@https://download.pytorch.org/whl/cpu/torch-2.4.1%2Bcpu-cp38-cp38-linux_x86_64.whl ; python_version=='3.8' and platform_machine=='x86_64'",