diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 0e1118407a..6ebccdfbf8 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
-from transformers.cache_utils import DynamicCache, DynamicLayer, EncoderDecoderCache, HybridCache, HybridChunkedCache
+from transformers.cache_utils import Cache, CacheLayerMixin, EncoderDecoderCache, HybridCache, HybridChunkedCache
 
 from QEfficient.customop import (
     CtxGatherFunc,
@@ -54,7 +54,47 @@ def _get_invalid_idx_value(cls):
         return 0
 
 
-class QEffDynamicLayer(DynamicLayer):
+class QEffDynamicLayer(CacheLayerMixin):
+    is_sliding = False
+
+    def __init__(self):
+        super().__init__()
+
+    def lazy_initialization(self, key_states: torch.Tensor):
+        self.dtype = key_states.dtype
+        self.device = key_states.device
+        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
+        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
+        self.is_initialized = True
+
+    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+        kv_offset = 0
+        query_length = cache_position.shape[0]
+        kv_length = self.get_seq_length() + query_length
+        return kv_length, kv_offset
+
+    def get_seq_length(self) -> int:
+        if self.keys is None or self.keys.numel() == 0:
+            return 0
+        return self.keys.shape[-2]
+
+    def get_max_cache_shape(self) -> int:
+        return -1
+
+    @classmethod
+    def from_tensors(cls, key_states: torch.Tensor, value_states: torch.Tensor) -> "QEffDynamicLayer":
+        layer = cls()
+        layer.keys = key_states
+        layer.values = value_states
+        layer._mark_initialized(key_states)
+        return layer
+
+    def _mark_initialized(self, reference_states: torch.Tensor) -> None:
+        if not self.is_initialized:
+            self.dtype = reference_states.dtype
+            self.device = reference_states.device
+            self.is_initialized = True
+
     def read_only(self, cache_kwargs):
         """
         Reads the `key_states` and `value_states` for the layer.
@@ -68,6 +108,8 @@ def read_only(self, cache_kwargs):
         """
         # Gather
         k_out, v_out = self.keys, self.values
+        if k_out is not None:
+            self._mark_initialized(k_out)
         position_ids = cache_kwargs.get("position_ids")
         batch_index = cache_kwargs.get("batch_index", None)
         ctx_len = cache_kwargs.get("CCL", k_out.shape[2])
@@ -109,6 +151,8 @@ def read_only_blockedKV(self, start_index, end_index, cache_kwargs):
         """
         # Gather
         k_out, v_out = self.keys, self.values
+        if k_out is not None:
+            self._mark_initialized(k_out)
         position_ids = cache_kwargs.get("position_ids")
         batch_index = cache_kwargs.get("batch_index", None)
         batch, num_kv_heads, _, _ = k_out.shape
@@ -150,7 +194,9 @@ def write_only(self, key_states, value_states, cache_kwargs):
         if self.keys is None:
             self.keys = key_states
             self.values = value_states
+            self._mark_initialized(self.keys)
         else:
+            self._mark_initialized(self.keys)
             position_ids = cache_kwargs.get("position_ids")
             batch_index = cache_kwargs.get("batch_index", None)
             # Check and fetch batch index value form the kwargs
@@ -189,8 +235,10 @@ def update(
         if self.keys is None:
             self.keys = key_states
             self.values = value_states
+            self._mark_initialized(self.keys)
             k_out, v_out = self.keys, self.values
         else:
+            self._mark_initialized(self.keys)
             position_ids = cache_kwargs.get("position_ids")
             batch_index = cache_kwargs.get("batch_index", None)
             # Check and fetch batch index value form the kwargs
@@ -252,8 +300,10 @@ def update3D(
         if self.keys is None:
             self.keys = key_states
             self.values = value_states
+            self._mark_initialized(self.keys)
             k_out, v_out = self.keys, self.values
         else:
+            self._mark_initialized(self.keys)
             position_ids = cache_kwargs.get("position_ids")
             batch_index = cache_kwargs.get("batch_index", None)
@@ -293,7 +343,7 @@ def update3D(
         return k_out, v_out
 
 
-class QEffDynamicCache(DynamicCache):
+class QEffDynamicCache(Cache):
     """
     A cache that grows dynamically as more tokens are generated. This is the default for generative models.
@@ -307,15 +357,46 @@ class QEffDynamicCache(DynamicCache):
     """
 
     def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
-        # Remove layer_classes if present to avoid duplicate argument
+        # Remove cache-layer construction args if present to avoid duplicate arguments.
         kwargs.pop("layer_classes", None)
-        from transformers.cache_utils import Cache  # Import here to avoid circular import
-
-        Cache.__init__(self, layer_classes=QEffDynamicLayer, *args, **kwargs)
+        kwargs.pop("layers", None)
+        kwargs.pop("layer_class_to_replicate", None)
+
+        try:
+            # transformers>=4.57
+            Cache.__init__(self, *args, layer_class_to_replicate=QEffDynamicLayer, **kwargs)
+        except TypeError:
+            # transformers<=4.56
+            Cache.__init__(self, *args, layer_classes=QEffDynamicLayer, **kwargs)
         if ddp_cache_data is not None:
             for key_states, value_states in ddp_cache_data:
                 self.layers.append(QEffDynamicLayer.from_tensors(key_states, value_states))
 
+    def append_new_layers(self, layer_idx: int) -> None:
+        while len(self.layers) <= layer_idx:
+            self.layers.append(QEffDynamicLayer())
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "QEffDynamicCache":
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        legacy_cache = ()
+        for layer in self.layers:
+            legacy_cache += ((layer.keys, layer.values),)
+        return legacy_cache
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
+        """
+        Keep backward-compatible call shape while deferring to upstream implementation.
+        """
+        return super().get_seq_length(layer_idx)
+
     def read_only(self, layer_idx, cache_kwargs):
         """
         Reads the `key_states` and `value_states` for the layer `layer_idx`.
@@ -405,10 +486,7 @@ def from_legacy_cache(
         cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
-        cache = cls(
-            self_attention_cache=QEffDynamicCache(),
-            cross_attention_cache=QEffDynamicCache(),
-        )
+        cache = cls(QEffDynamicCache(), QEffDynamicCache())
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx][:2]
@@ -419,6 +497,18 @@ def from_legacy_cache(
                 cache.is_updated[layer_idx] = True
         return cache
 
+    def to_legacy_cache(self):
+        self_attn_legacy = self.self_attention_cache.to_legacy_cache()
+        cross_attn_legacy = self.cross_attention_cache.to_legacy_cache()
+
+        legacy_cache = ()
+        for layer_idx, self_attn_layer in enumerate(self_attn_legacy):
+            if layer_idx < len(cross_attn_legacy):
+                legacy_cache += (self_attn_layer + cross_attn_legacy[layer_idx],)
+            else:
+                legacy_cache += (self_attn_layer,)
+        return legacy_cache
+
 
 # TODO:This function will be depercated in future.
 class QEffHybridCache(HybridCache):
@@ -447,7 +537,7 @@ def __len__(self):
         """
         return len(self.key_cache)
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+    def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int:
         """Returns the sequence length of the cached states.
A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` is_empty_layer = ( @@ -531,7 +621,7 @@ def __len__(self): """ return len(self.key_cache) - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` is_empty_layer = ( @@ -663,7 +753,7 @@ def __len__(self): """ return len(self.key_cache) - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` is_empty_layer = ( @@ -783,7 +873,7 @@ def __len__(self): """ return len(self.key_cache) - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position: Optional[torch.LongTensor] = None) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` is_empty_layer = ( diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 81ce9acf40..a29d0e0966 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -191,6 +191,7 @@ ] ) + # This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index e70f32818f..26080a59a8 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -104,7 +104,7 @@ def forward( alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Cache] = None, @@ -190,7 +190,7 @@ def forward( layer_past=layer_past, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, alibi=alibi, diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 3bed2d00ee..0d740c717e 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -123,7 +123,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -143,12 +143,12 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if 
past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -210,7 +210,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 8e2e823c7f..ac6de7de4c 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -130,7 +130,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -150,7 +150,7 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin_cached, @@ -161,7 +161,7 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -177,7 +177,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class QEffGemma2DecoderLayer(Gemma2DecoderLayer): @@ -227,7 +227,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index f98bae2257..3d5a19bf96 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -214,7 +214,7 @@ def forward( position_embeddings: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -232,7 +232,7 @@ def forward( query_states = 
self.q_norm(query_states) key_states = self.k_norm(key_states) - if past_key_value is not None: + if past_key_values is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " @@ -245,7 +245,7 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -254,12 +254,12 @@ def forward( "position_ids": position_ids, "is_sliding": self.is_sliding, "sliding_window_pattern": self.config.sliding_window_pattern, - "sliding_window": past_key_value.sliding_window_len, + "sliding_window": past_key_values.sliding_window_len, } if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -330,7 +330,7 @@ def forward( position_embeddings=None, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index 7de674cce9..1872e64ab1 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -64,7 +64,7 @@ class QEffGPT2Attention(GPT2Attention): def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -76,16 +76,16 @@ def forward( **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, QEffEncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, QEffEncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values if is_cross_attention: if not hasattr(self, "q_attn"): @@ -98,7 +98,7 @@ def forward( attention_mask = encoder_attention_mask # Try to get key/value states from cache if possible - if past_key_value is not None and is_updated: + if past_key_values is not None and is_updated: 
key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values else: @@ -116,8 +116,8 @@ def forward( shape_q = (*query_states.shape[:-1], -1, self.head_dim) query_states = query_states.view(shape_q).transpose(1, 2) - if (past_key_value is not None and not is_cross_attention) or ( - past_key_value is not None and is_cross_attention and not is_updated + if (past_key_values is not None and not is_cross_attention) or ( + past_key_values is not None and is_cross_attention and not is_updated ): # save all key/value_layer to cache to be re-used for fast auto-regressive generation # Update the cache_kwargs with position_ids for Cloud AI 100 @@ -131,7 +131,7 @@ def forward( ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward attn_output, attn_weights = attention_interface( @@ -178,7 +178,7 @@ def forward( hidden_states = self.ln_1(hidden_states) attn_output, self_attn_weights = self.attn( hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_value, attention_mask=attention_mask, comp_ctx_lengths=comp_ctx_lengths, position_ids=position_ids, diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6f4e1d8c43..d0b9283535 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -733,7 +733,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -752,7 +752,7 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin_cached, @@ -764,14 +764,14 @@ def forward( "sliding_window": self.sliding_window, } if self.sliding_window is not None: - key_states, value_states = past_key_value.sliding_window_update_chunked( + key_states, value_states = past_key_values.sliding_window_update_chunked( key_states, value_states, self.layer_idx, cache_kwargs ) else: if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.full_cache_update_chunked( + key_states, value_states = past_key_values.full_cache_update_chunked( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -805,7 +805,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class QEffPrefillOnlyGptOssAttention(GptOssAttention): @@ -817,7 +817,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, 
+ past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -836,7 +836,7 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin_cached, @@ -845,11 +845,11 @@ def forward( "position_ids": position_ids, "config": self.config, "is_sliding": self.sliding_window is not None, - "sliding_window": past_key_value.sliding_window_len, + "sliding_window": past_key_values.sliding_window_len, } if self.sliding_window is not None: - sliding_window_len = past_key_value.sliding_window_len - short_read_idx = torch.arange(past_key_value.key_cache[self.layer_idx].shape[2]) + sliding_window_len = past_key_values.sliding_window_len + short_read_idx = torch.arange(past_key_values.key_cache[self.layer_idx].shape[2]) read_idx = short_read_idx + torch.where( position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 ) @@ -859,7 +859,7 @@ def forward( v_cache = value_states[:, :, read_idx, :] else: k_cache, v_cache = key_states, value_states - _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs) + _, _ = past_key_values.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs) if self.sliding_window is not None: attention_mask = sliding_mask @@ -885,7 +885,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class QEffGptOssAttention(GptOssAttention): @@ -897,7 +897,7 @@ def forward( position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -916,7 +916,7 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin_cached, @@ -925,12 +925,12 @@ def forward( "position_ids": position_ids, "config": self.config, "is_sliding": self.sliding_window is not None, - "sliding_window": past_key_value.sliding_window_len, + "sliding_window": past_key_values.sliding_window_len, } if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) if self.sliding_window is not None: attention_mask = sliding_mask @@ -953,7 +953,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class QEffGptOssDecoderLayer(GptOssDecoderLayer): @@ -981,7 +981,7 @@ def forward( 
hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py index a4c81dbecb..bbf621f106 100644 --- a/QEfficient/transformers/models/gptj/modeling_gptj.py +++ b/QEfficient/transformers/models/gptj/modeling_gptj.py @@ -223,7 +223,7 @@ def forward( else: past_length = past_key_values[0][0].size(-2) - if not self._use_flash_attention_2: + if not getattr(self, "_use_flash_attention_2", False): attention_mask = _create_causal_mask(position_ids, past_length, None) # # Prepare head mask if needed diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index c2af97f55d..81aa192945 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -116,7 +116,7 @@ def forward( hidden_states: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -136,7 +136,7 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin_cached, @@ -147,7 +147,7 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward attn_output, attn_weights = attention_interface( @@ -176,7 +176,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, @@ -216,7 +216,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_values=past_key_values, + past_key_values=past_key_value, output_attentions=output_attentions, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 82bb8533a6..40359e7c89 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -107,7 +107,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: 
Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, @@ -130,7 +130,7 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin_cached, @@ -142,7 +142,7 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -242,7 +242,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 1a1c919bb1..51bdaa4ea4 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -59,6 +59,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -87,8 +88,9 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len = past_key_value.get_seq_length(layer_idx) + kv_seq_len = past_key_value.get_seq_length(layer_idx, cache_position) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 5b501d36fa..00f97e24d2 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -193,7 +193,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, @@ -216,25 +216,27 @@ def forward( value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 query_states, key_states = qeff_apply_rotary_pos_emb( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: if 
num_kv_blocks is not None: cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, "past_seen_tokens": past_seen_tokens, } - past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + past_key_values.write_only(key_states, value_states, self.layer_idx, cache_kwargs) else: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) if num_kv_blocks is not None: attention_interface = eager_attention_forward_blockedKV @@ -251,7 +253,7 @@ def forward( num_kv_blocks=num_kv_blocks, cache_kwargs=cache_kwargs, layer_idx=self.layer_idx, - past_key_value=past_key_value, + past_key_value=past_key_values, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() @@ -290,7 +292,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 85187d33ed..15bc1a7365 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -469,7 +469,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -502,7 +502,7 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: chunk_position_ids = position_ids if self.use_rope and self.config.attention_chunk_size: chunk_position_ids = torch.where( @@ -518,7 +518,7 @@ def forward( attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -534,7 +534,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class QEffLlama4TextDecoderLayer(Llama4TextDecoderLayer): @@ -569,7 +569,7 @@ def forward( attention_mask=attention_mask, position_embeddings=position_embeddings, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py 
b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 2e8a526d79..3667af854e 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -86,7 +86,7 @@ def forward( self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: torch.Tensor = None, batch_index: Optional[torch.LongTensor] = None, @@ -100,7 +100,7 @@ def forward( query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} - if past_key_value is not None: + if past_key_values is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " @@ -111,7 +111,7 @@ def forward( attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) + key_states, value_states = past_key_values.read_only(self.layer_idx, cache_kwargs=cache_kwargs) position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( @@ -140,7 +140,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - return attn_output, past_key_value + return attn_output, past_key_values class QEffLlamaSwiftKVDecoderLayer(nn.Module): @@ -177,7 +177,7 @@ def forward( hidden_states, past_key_values = self.self_attn( hidden_states=hidden_states, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 76a7d24c64..14aee1cf42 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -126,7 +126,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, @@ -153,12 +153,12 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -222,7 +222,7 @@ def forward( hidden_states=hidden_states, 
attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index e59a3be534..12c8ee99fa 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -123,7 +123,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cos_cached: Optional[torch.Tensor] = None, @@ -137,7 +137,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " @@ -150,12 +150,12 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -289,7 +289,7 @@ def forward( position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 642fb4bb7c..a22e7960f6 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -166,7 +166,7 @@ def forward( self, hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -181,8 +181,8 @@ def forward( # elif past_key_value is not None: # Fetch old cache - key_states_old = past_key_value.layers[self.layer_idx].keys - value_states_old = past_key_value.layers[self.layer_idx].values + key_states_old = past_key_values.layers[self.layer_idx].keys + value_states_old = past_key_values.layers[self.layer_idx].values # if cross_attention_states is not None: # Compute new KV states @@ -203,8 +203,8 @@ def forward( value_states = torch.where(torch.tensor(q_len == 1), value_states_old, value_states_new) # Update the image cache - past_key_value.layers[self.layer_idx].keys = key_states - past_key_value.layers[self.layer_idx].values 
= value_states + past_key_values.layers[self.layer_idx].keys = key_states + past_key_values.layers[self.layer_idx].values = value_states key_states = self.k_norm(key_states) @@ -236,7 +236,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_embeddings: torch.Tensor = None, @@ -256,7 +256,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " @@ -269,7 +269,7 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, @@ -277,7 +277,7 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_self_attention_forward @@ -347,7 +347,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, @@ -379,7 +379,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -397,23 +397,23 @@ def forward( value_states = self.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! 
- key_states, value_states = past_key_value.update( + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, cache_kwargs, ) - elif past_key_value is not None: + elif past_key_values is not None: key_states, value_states = ( - past_key_value.layers[self.layer_idx].keys, - past_key_value.layers[self.layer_idx].values, + past_key_values.layers[self.layer_idx].keys, + past_key_values.layers[self.layer_idx].values, ) else: raise ValueError( @@ -469,7 +469,7 @@ def forward( hidden_states=hidden_states, attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, cache_position=cache_position, diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index 57f2729b91..fdb646d1fe 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -250,6 +250,8 @@ def attention( ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: B, T, C = q.size() # batch size, sequence length, d_model dtype = k.dtype + cache_position = kwargs.get("cache_position") + cos, sin = None, None # Optionally apply layer norm to keys and queries. if self.q_norm is not None and self.k_norm is not None: @@ -264,28 +266,19 @@ def attention( # shape: (B, n_kv_h, T, hs) v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) - if self.config.use_position_ids and self.config.rope: + if self.config.rope: kv_seq_len = k.shape[-2] - kv_seq_len = layer_past.get_seq_length(self.layer_id) - # Apply rotary embeddings - cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) - q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) - - if not self.config.use_position_ids and self.config.rope: - kv_seq_len = k.shape[-2] - kv_seq_len = layer_past.get_seq_length(kv_seq_len, self.layer_id) + if layer_past is not None: + kv_seq_len = layer_past.get_seq_length(self.layer_id, cache_position) # Apply rotary embeddings cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) if layer_past is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "batch_index": batch_index, - "position_ids": position_ids, - } + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + # sin/cos are specific to RoPE models and are needed for some static cache paths. 
+ if self.config.rope: + cache_kwargs.update({"sin": sin, "cos": cos}) if comp_ctx_lengths is not None: attention_bias = attention_bias[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_bias.shape[-1] diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 22834d2926..fe2ebee128 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -118,7 +118,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -143,13 +143,13 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -198,7 +198,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py index 82f18b7e08..9847146ada 100644 --- a/QEfficient/transformers/models/phi/modeling_phi.py +++ b/QEfficient/transformers/models/phi/modeling_phi.py @@ -66,7 +66,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, @@ -104,7 +104,7 @@ def forward( query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # Update the cache_kwargs with position_ids for Cloud AI 100 cache_kwargs = { "sin": sin, @@ -115,7 +115,7 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward attn_output, attn_weights = attention_interface( @@ -190,7 +190,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, - past_key_value=past_key_value, + past_key_values=past_key_value, 
comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index b18dbcd5c0..cf00205f45 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -126,7 +126,7 @@ def forward( attention_mask: Optional[torch.Tensor], batch_index: Optional[torch.LongTensor] = None, position_ids=Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, cos_cached: Optional[torch.Tensor] = None, @@ -152,7 +152,7 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, @@ -160,7 +160,7 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward attn_output, attn_weights = attention_interface( @@ -231,7 +231,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index df7421c466..a76113fd09 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -136,7 +136,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -156,12 +156,12 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -227,7 +227,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index f333302bc0..43272a8c2d 100644 --- 
a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -559,7 +559,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, @@ -582,13 +582,13 @@ def forward( # kv_seq_len = key_states.shape[-2] # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 query_states, key_states = qeff_apply_rotary_pos_emb( query_states, key_states, cos_cached, sin_cached, position_ids[1:], self.rope_scaling["mrope_section"] ) - if past_key_value is not None: + if past_key_values is not None: if num_kv_blocks is not None: cache_kwargs = { "sin": sin_cached, @@ -597,7 +597,7 @@ def forward( "position_ids": position_ids[0], "past_seen_tokens": past_seen_tokens, } - past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + past_key_values.write_only(key_states, value_states, self.layer_idx, cache_kwargs) else: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { @@ -609,7 +609,9 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) attention_interface: Callable = eager_attention_forward @@ -622,7 +624,7 @@ def forward( num_kv_blocks=num_kv_blocks, cache_kwargs=cache_kwargs, layer_idx=self.layer_idx, - past_key_value=past_key_value, + past_key_value=past_key_values, **kwargs, ) @@ -633,7 +635,7 @@ def forward( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class QEffQwen2_5_VLDecoderLayer(Qwen2_5_VLDecoderLayer): @@ -684,7 +686,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index 4202f52e18..d1069f2251 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -137,7 +137,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -157,12 +157,12 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = 
{"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -228,7 +228,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index fb7320ff6a..f040e5ecf0 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -183,7 +183,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -203,12 +203,12 @@ def forward( query_states, key_states, cos_cached, sin_cached, position_ids ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -266,7 +266,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index fdbbbf05dc..a66734a8e4 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -68,7 +68,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, @@ -84,12 +84,12 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] 
= attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -157,7 +157,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_value, comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/t5/modeling_t5.py b/QEfficient/transformers/models/t5/modeling_t5.py index f54201465c..8fd69ffd78 100644 --- a/QEfficient/transformers/models/t5/modeling_t5.py +++ b/QEfficient/transformers/models/t5/modeling_t5.py @@ -39,7 +39,7 @@ def forward( mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -60,18 +60,18 @@ def forward( query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -81,7 +81,7 @@ def forward( key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -89,7 +89,7 @@ def forward( ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -108,7 +108,7 @@ def forward( position_bias = self.compute_bias( real_seq_length, key_length, device=scores.device, cache_position=cache_position ) - if past_key_value is not None: # This block is where the patch applies + if past_key_values is not None: # 
This block is where the patch applies position_bias = position_bias[:, :, -1:, :] # Added by patch if mask is not None: diff --git a/QEfficient/transformers/quantizers/quant_transforms.py b/QEfficient/transformers/quantizers/quant_transforms.py index 69d6380f0e..f97bfe998e 100644 --- a/QEfficient/transformers/quantizers/quant_transforms.py +++ b/QEfficient/transformers/quantizers/quant_transforms.py @@ -7,15 +7,22 @@ import torch from torch import nn +from transformers import AutoConfig from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts from QEfficient.base.pytorch_transforms import ModuleMutatorTransform from QEfficient.customop.matmulnbits import QuantLinearORT from QEfficient.transformers.quantizers.awq import WQLinear_GEMM from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ -from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear +from QEfficient.transformers.quantizers.quantizer_compressed_tensors import ( + FP8BlockWiseDequantLinear, + FP8BlockWiseDequantQwen3VLMoeTextExperts, + FP8DeQuantLinear, +) from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts from QEfficient.transformers.quantizers.quantizer_utils import ( + blockwise_dequantize, convert_moe_packed_tensors, dequantize_gptq, unpack_weights, @@ -146,3 +153,57 @@ def mutate(cls, original_module, parent_module): dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias dequant_module.down_proj_bias = original_module.down_proj_bias return dequant_module + + +class FP8BlockWiseDequantLinearToLinearTransform(ModuleMutatorTransform): + """ + Used to dequantize the weights of an FP8BlockWiseDequantLinear module and replace with a regular Linear layer + """ + + _match_class = FP8BlockWiseDequantLinear + + @classmethod + def mutate(cls, original_module, parent_module): + # -- de-quantizing the weights -- + dequant_weights = blockwise_dequantize( + original_module.weight, original_module.weight_scale_inv, original_module.weight_block_size + ) + dequant_linear_layer = nn.Linear( + original_module.in_features, original_module.out_features, bias=original_module.bias is not None + ) + dequant_linear_layer.weight = torch.nn.Parameter(dequant_weights) + if original_module.bias is not None: + dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float()) + return dequant_linear_layer + + +class FP8BlockWiseDequantQwen3VLMoeTextExpertsToQwen3VLMoeTextExpertsTransform(ModuleMutatorTransform): + _match_class = FP8BlockWiseDequantQwen3VLMoeTextExperts + _model_type = "qwen3_vl_moe" + + @classmethod + def mutate(cls, original_module, parent_module): + config = AutoConfig.for_model(cls._model_type).text_config + config.num_experts = original_module.num_experts + config.intermediate_size = original_module.intermediate_size + config.hidden_size = original_module.hidden_size + assert original_module.act_fn.__class__.__name__ == "SiLUActivation", ( + "Only SiLU activation is supported for now." 
+ ) + assert config.hidden_act == "silu", "expected silu act fn, something changed in transformers code" + dequant_module = Qwen3VLMoeTextExperts(config) + dequant_module.gate_up_proj = torch.nn.Parameter( + blockwise_dequantize( + original_module.gate_up_proj, + original_module.gate_up_proj_scale_inv, + original_module.weights_block_size, + ) + ) + dequant_module.down_proj = torch.nn.Parameter( + blockwise_dequantize( + original_module.down_proj, + original_module.down_proj_scale_inv, + original_module.weights_block_size, + ) + ) + return dequant_module diff --git a/QEfficient/transformers/quantizers/quantizer_awq.py b/QEfficient/transformers/quantizers/quantizer_awq.py index ef8a03521f..b7199a71ea 100644 --- a/QEfficient/transformers/quantizers/quantizer_awq.py +++ b/QEfficient/transformers/quantizers/quantizer_awq.py @@ -29,15 +29,18 @@ def post_init(self): f"Only quantization backend {AwqBackendPackingMethod.AUTOAWQ} is supported - not recognized backend {self.backend}" ) - self.version = AWQLinearVersion.from_str(self.version) + if isinstance(self.version, str): + self.version = AWQLinearVersion.from_str(self.version) if self.version not in [AWQLinearVersion.GEMM]: raise ValueError( f"Only {AWQLinearVersion.GEMM} version in supported - not recognized version {self.version}" ) - if self.do_fuse or self.fuse_max_seq_len is not None: + do_fuse = getattr(self, "do_fuse", None) + fuse_max_seq_len = getattr(self, "fuse_max_seq_len", None) + if do_fuse or fuse_max_seq_len is not None: raise ValueError( - f"fused modules are not supported, got do_fuse={self.do_fuse}, fuse_max_seq_len={self.fuse_max_seq_len}" + f"fused modules are not supported, got do_fuse={do_fuse}, fuse_max_seq_len={fuse_max_seq_len}" ) if self.bits != 4: @@ -63,6 +66,9 @@ def update_torch_dtype(self, torch_dtype): logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") return None + def update_dtype(self, dtype): + return self.update_torch_dtype(dtype) + def _process_model_before_weight_loading(self, model, **kwargs): self.modules_to_not_convert = get_keys_to_not_convert(model) diff --git a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py index e7e14166d9..f7ecc5b218 100644 --- a/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py +++ b/QEfficient/transformers/quantizers/quantizer_compressed_tensors.py @@ -10,10 +10,11 @@ from typing import List import torch +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer from transformers.utils.quantization_config import CompressedTensorsConfig, QuantizationConfigMixin, QuantizationMethod -from QEfficient.transformers.quantizers.quantizer_utils import get_keys_to_not_convert +from QEfficient.transformers.quantizers.quantizer_utils import blockwise_dequantize, get_keys_to_not_convert from QEfficient.utils.logging_utils import logger FP8_DTYPE = torch.float8_e4m3fn @@ -128,6 +129,118 @@ def forward(self, x): return out +class FP8BlockWiseDequantLinear(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + weight_block_size: List[int], + bias: bool = False, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight_block_size = weight_block_size + + self.register_buffer( + "weight", + torch.empty( + (out_features, in_features), dtype=FP8_DTYPE + ), # 
This is fixed for now and only e4m3fn quantization is prominent + ) + + if bias: + self.register_buffer( + "bias", + torch.zeros( + (out_features), + dtype=torch.float16, + ), + ) + else: + self.bias = None + + @classmethod + def for_fp8_layer_with_blocksize(cls, in_features, out_features, weight_block_size, fmt, bias): + fp8_dequant_layer = cls(in_features, out_features, weight_block_size, bias) + assert fmt == "e4m3", "e5m2 is not supported yet!!" + assert (in_features % weight_block_size[0]) == 0 and (out_features % weight_block_size[1]) == 0, ( + f"weight shape is not divisible by block sizes in either rows or columns or both dimensions, \ + got in_features: {in_features}, out_features: {out_features}, weight_block_size: {weight_block_size}!!" + ) + fp8_dequant_layer.register_buffer( + "weight_scale_inv", + torch.empty( + (out_features // weight_block_size[0], in_features // weight_block_size[1]), dtype=torch.float32 + ), + ) + return fp8_dequant_layer + + def __repr__(self): + return f"FP8BlockWiseDequantLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias})" + + def forward(self, x): + with torch.no_grad(): + dequantized_weights = blockwise_dequantize(self.weight, self.weight_scale_inv, self.weight_block_size) + out = torch.matmul(x.float(), dequantized_weights.T) + out = out + self.bias if self.bias is not None else out + + return out + + + class FP8BlockWiseDequantQwen3VLMoeTextExperts(torch.nn.Module): + def __init__(self, num_experts, moe_intermediate_size, hidden_size, act_fn, weights_block_size): + super().__init__() + self.num_experts = num_experts + self.intermediate_size = moe_intermediate_size + self.hidden_size = hidden_size + self.expert_dim = self.intermediate_size + self.weights_block_size = weights_block_size + r, c = weights_block_size + self.register_buffer( + "gate_up_proj", torch.empty((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=FP8_DTYPE) + ) + self.register_buffer( + "down_proj", torch.empty((self.num_experts, self.expert_dim, self.hidden_size), dtype=FP8_DTYPE) + ) + self.register_buffer( + "gate_up_proj_scale_inv", + torch.empty((self.num_experts, self.hidden_size // r, (2 * self.expert_dim) // c), dtype=torch.float32), + ) + self.register_buffer( + "down_proj_scale_inv", + torch.empty((self.num_experts, self.expert_dim // r, self.hidden_size // c), dtype=torch.float32), + ) + self.act_fn = act_fn + + @classmethod + def for_fp8_layer_with_blocksize(cls, old_module, weight_block_size, fmt): + assert fmt == "e4m3", "e5m2 is not supported yet!!"
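+ # Reuse the original module's shapes and activation function; the FP8 weight and *_scale_inv buffers are allocated empty in __init__ and are filled in when the checkpoint weights are loaded.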
+ fp8_experts = cls( + num_experts=old_module.num_experts, + moe_intermediate_size=old_module.intermediate_size, + hidden_size=old_module.hidden_size, + act_fn=old_module.act_fn, + weights_block_size=weight_block_size, + ) + return fp8_experts + + def forward(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor): + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) + gate_up_proj = blockwise_dequantize(self.gate_up_proj, self.gate_up_proj_scale_inv, self.weights_block_size) + down_proj = blockwise_dequantize(self.down_proj, self.down_proj_scale_inv, self.weights_block_size) + gate_up = torch.bmm(hidden_states, gate_up_proj) + gate, up = gate_up.chunk(2, dim=-1) # not supported for DTensors + next_states = torch.bmm((up * self.act_fn(gate)), down_proj) + next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + + + class QEffFP8Config(QuantizationConfigMixin): def __init__( self, @@ -136,6 +249,8 @@ def __init__( ignored_layers: List[str] = None, kv_cache_scheme: str = None, run_compressed: bool = False, + fmt: str = None, + weight_block_size: List[int] = None, ): self.quant_method = quant_method self.activation_scheme = activation_scheme @@ -155,6 +270,52 @@ def __init__( ) self.quant_method = QEffExtendedQuantizationMethod.FP8 + self.fmt = fmt + self.weight_block_size = weight_block_size + + +def _replace_with_fp8_dequant_linear_and_experts_if_qwen( + model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False +): + current_key_name = [] if current_key_name is None else current_key_name + + for name, child_module in model.named_children(): + current_key_name.append(name) + + if isinstance(child_module, torch.nn.Linear) and name not in (modules_to_not_convert or []): + current_key_name_str = ".".join(current_key_name) + if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): + model._modules[name] = FP8BlockWiseDequantLinear.for_fp8_layer_with_blocksize( + child_module.in_features, + child_module.out_features, + quantization_config.weight_block_size, + quantization_config.fmt, + child_module.bias is not None, + ) + has_been_replaced = True + + if isinstance(child_module, Qwen3VLMoeTextExperts) and name not in (modules_to_not_convert or []): + # Replace the MoE experts + current_key_name_str = ".".join(current_key_name) + if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): + model._modules[name] = FP8BlockWiseDequantQwen3VLMoeTextExperts.for_fp8_layer_with_blocksize( + child_module, + quantization_config.weight_block_size, + quantization_config.fmt, + ) + has_been_replaced = True + + if len(list(child_module.children())) > 0: + _, has_been_replaced = _replace_with_fp8_dequant_linear_and_experts_if_qwen( + child_module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + + current_key_name.pop(-1) + return model, has_been_replaced class QEffFP8Quantizer(CompressedTensorsHfQuantizer): @@ -196,6 +357,12 @@ def _process_model_before_weight_loading(self, model, **kwargs):
f"activations quantization strategy = {self.quantization_config.activation_scheme}, will be ignored and the layers will be run with de-quantized weights" ) + if self.quantization_config.weight_block_size is not None: + model, has_been_replaced = _replace_with_fp8_dequant_linear_and_experts_if_qwen( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + return + # -- Defining local method as it uses lot of local variables -- def replace_linear_with_fp8_dequant_layer(module): for name, child_module in module.named_children(): @@ -218,7 +385,7 @@ def _process_model_after_weight_loading(self, model, **kwargs): def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]: return missing_keys - def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]: + def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str = None) -> List[str]: return unexpected_keys @@ -399,5 +566,5 @@ def _process_model_after_weight_loading(self, model, **kwargs): def update_missing_keys_after_loading(self, model, missing_keys: List[str], prefix: str) -> List[str]: return missing_keys - def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str) -> List[str]: + def update_unexpected_keys(self, model, unexpected_keys: List[str], prefix: str = None) -> List[str]: return unexpected_keys diff --git a/QEfficient/transformers/quantizers/quantizer_mxfp4.py b/QEfficient/transformers/quantizers/quantizer_mxfp4.py index 2ffba1beaa..44c255feb5 100644 --- a/QEfficient/transformers/quantizers/quantizer_mxfp4.py +++ b/QEfficient/transformers/quantizers/quantizer_mxfp4.py @@ -105,6 +105,9 @@ def update_torch_dtype(self, torch_dtype): logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None") return None + def update_dtype(self, dtype): + return self.update_torch_dtype(dtype) + def _process_model_before_weight_loading( self, model: torch.nn.Module, diff --git a/QEfficient/transformers/quantizers/quantizer_utils.py b/QEfficient/transformers/quantizers/quantizer_utils.py index 424692d087..4060a162d4 100644 --- a/QEfficient/transformers/quantizers/quantizer_utils.py +++ b/QEfficient/transformers/quantizers/quantizer_utils.py @@ -7,6 +7,7 @@ import copy import math +from typing import List import torch from torch import nn @@ -446,3 +447,26 @@ def convert_moe_packed_tensors( out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) out = out.to(dtype).permute(0, 2, 1).contiguous() return out + + +def blockwise_dequantize( + quantized: torch.Tensor, + scales: torch.Tensor, + block_size: List[int] = None, + **kwargs, +) -> dict[str, torch.Tensor]: + rows, cols = quantized.shape[-2:] + if block_size is None: + block_size = (quantized.shape[-2], quantized.shape[-1]) + + block_m, block_n = block_size + + if rows % block_m != 0 or cols % block_n != 0: + raise ValueError(f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}).") + quantized = quantized.to(scales.dtype) + reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n) + expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n) + expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2) + dequantized = reshaped * expanded_scales + + return dequantized.reshape(quantized.shape) diff --git a/pyproject.toml b/pyproject.toml index d5b38cdcfd..62636f96ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ 
classifiers = [ ] requires-python = ">=3.8,<3.13" dependencies = [ - "transformers==4.55.0", + "transformers==4.57.3", "diffusers== 0.35.1", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", @@ -55,7 +55,7 @@ dependencies = [ ] [project.optional-dependencies] -test = ["pytest","pytest-mock"] +test = ["pytest","pytest-mock","pytest-xdist"] docs = ["Sphinx==7.1.2","sphinx-rtd-theme==2.0.0","myst-parser==3.0.1","sphinx-multiversion"] quality = ["black", "ruff", "hf_doc_builder@git+https://github.com/huggingface/doc-builder.git"] diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 2e75c7e696..6ea8a98b62 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -78,7 +78,7 @@ pipeline { mkdir -p $PWD/Non_qaic_feature && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_qaic_feature && - pytest tests -m '(not cli) and (on_qaic) and (feature) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log2_feature.xml --durations=10 && + pytest tests -m '(not cli) and (on_qaic) and (feature) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --ignore tests/transformers/sampler --junitxml=tests/tests_log2_feature.xml --durations=10 && junitparser merge tests/tests_log2_feature.xml tests/tests_log.xml && deactivate" ''' diff --git a/tests/conftest.py b/tests/conftest.py index d1f553cda3..7e21127482 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,55 @@ from QEfficient.utils.constants import QEFF_MODELS_DIR from QEfficient.utils.logging_utils import logger +_QUICKCHECK_FILE = "tests/unit_test/models/test_model_quickcheck.py" +_QUICKCHECK_SUMMARY = {} +_QUICKCHECK_META = { + "test_causal_lm_cpu_runtime_parity_with_api_runner": ( + "Causal LM", + "Full parity: HF PyTorch vs QEff PyTorch vs ORT tokens", + ), + "test_vlm_text_side_runtime_parity_and_full_export": ( + "VLM", + "Text-side full parity + full VLM export smoke", + ), + "test_vlm_export_smoke_additional_models": ( + "VLM", + "Export smoke with text-side fallback when needed", + ), + "test_text_embedding_cpu_parity_and_export": ( + "Text Embedding", + "Tensor parity: HF vs QEff PyTorch vs ORT", + ), + "test_audio_embedding_ctc_cpu_parity_and_export": ( + "Audio CTC", + "Logits parity: HF vs ORT + export", + ), + "test_seq_classification_cpu_parity_and_export": ( + "Sequence Classification", + "Logits parity: HF vs QEff PyTorch vs ORT", + ), + "test_whisper_export_smoke": ( + "Whisper", + "Export smoke + retained-state outputs check", + ), + "test_causal_subfunction_export_smoke": ( + "Causal LM", + "Subfunction export check (with/without QEffGPT2Block)", + ), + "test_causal_subfunction_export_smoke_all_models": ( + "Causal LM", + "Full parity: HF PyTorch vs QEff PyTorch vs ORT tokens (subfunctions)", + ), + "test_prefix_caching_continuous_batching_export_and_ort_smoke": ( + "Prefix Caching", + "Continuous-batching export structural checks", + ), + "test_awq_export_smoke": ( + "AWQ", + "Export smoke + MatMulNBits presence check", + ), +} + def qeff_models_clean_up(): if os.path.exists(QEFF_MODELS_DIR): @@ -42,3 +91,32 @@ def pytest_sessionfinish(session, exitstatus): if inside_worker is None: qeff_models_clean_up() logger.info("...PYTEST Session Ended.") + + +def pytest_runtest_logreport(report): + if _QUICKCHECK_FILE not in report.nodeid: + return + + if report.when == "call": + 
_QUICKCHECK_SUMMARY[report.nodeid] = report.outcome + return + + if report.when == "setup" and report.outcome == "skipped": + _QUICKCHECK_SUMMARY.setdefault(report.nodeid, report.outcome) + + +def pytest_terminal_summary(terminalreporter): + if not _QUICKCHECK_SUMMARY: + return + + terminalreporter.section("Quickcheck Coverage Summary", sep="=") + header = f"{'Status':7} {'Test Case':58} {'Category':24} Validation" + terminalreporter.write_line(header) + terminalreporter.write_line("-" * len(header)) + + for nodeid in sorted(_QUICKCHECK_SUMMARY): + test_case = nodeid.split("::", 1)[1] + base_name = test_case.split("[", 1)[0] + category, validation = _QUICKCHECK_META.get(base_name, ("Other", "N/A")) + status = _QUICKCHECK_SUMMARY[nodeid].upper() + terminalreporter.write_line(f"{status:7} {test_case:58} {category:24} {validation}") diff --git a/tests/diffusers/test_flux.py b/tests/diffusers/test_flux.py index 31c62e336e..6189dd01c0 100644 --- a/tests/diffusers/test_flux.py +++ b/tests/diffusers/test_flux.py @@ -28,6 +28,7 @@ # Test Configuration for 256x256 resolution with 2 layers # update mad tolerance CONFIG_PATH = "tests/diffusers/flux_test_config.json" INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) +TEST_SEED = 42 def flux_pipeline_call_with_mad_validation( @@ -164,7 +165,7 @@ def flux_pipeline_call_with_mad_validation( # Allocate output buffer for transformer output_buffer = { - "output": np.random.rand(batch_size, cl, pipeline.transformer.model.config.in_channels).astype(np.float32), + "output": np.zeros((batch_size, cl, pipeline.transformer.model.config.in_channels), dtype=np.float32), } pipeline.transformer.qpc_session.set_buffers(output_buffer) @@ -276,7 +277,7 @@ def flux_pipeline_call_with_mad_validation( ) # Allocate output buffer for VAE decoder - output_buffer = {"sample": np.random.rand(batch_size, 3, height, width).astype(np.float32)} + output_buffer = {"sample": np.zeros((batch_size, 3, height, width), dtype=np.float32)} pipeline.vae_decode.qpc_session.set_buffers(output_buffer) # MAD Validation for VAE @@ -315,6 +316,9 @@ def flux_pipeline_call_with_mad_validation( @pytest.fixture(scope="session") def flux_pipeline(): """Setup Flux test pipelines with random-initialized (dummy) weights.""" + torch.manual_seed(TEST_SEED) + np.random.seed(TEST_SEED) + config = INITIAL_TEST_CONFIG["model_setup"] model_id = "black-forest-labs/FLUX.1-schnell" @@ -351,6 +355,10 @@ def flux_pipeline(): tokenizer_2=tokenizer_2, transformer=transformer, ) + vae.eval() + transformer.eval() + text_encoder.eval() + text_encoder_2.eval() # Use QEff wrapper on a copy of the random-init reference model. 
import copy @@ -387,7 +395,7 @@ def test_flux_pipeline(flux_pipeline): max_sequence_length = config["pipeline_params"]["max_sequence_length"] # Generate with MAD validation - generator = torch.manual_seed(42) + generator = torch.Generator(device="cpu").manual_seed(TEST_SEED) start_time = time.time() try: diff --git a/tests/diffusers/test_wan.py b/tests/diffusers/test_wan.py index 19f8bd0b36..c21be7275d 100644 --- a/tests/diffusers/test_wan.py +++ b/tests/diffusers/test_wan.py @@ -34,6 +34,7 @@ # Test Configuration for 48 x 64 resolution with 1 layer CONFIG_PATH = "tests/diffusers/wan_test_config.json" INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) +TEST_SEED = 42 def wan_pipeline_call_with_mad_validation( @@ -194,11 +195,10 @@ def wan_pipeline_call_with_mad_validation( ) output_buffer = { - "output": np.random.rand( - batch_size, - pipeline.cl, - constants.WAN_DIT_OUT_CHANNELS, - ).astype(np.int32), + "output": np.zeros( + (batch_size, pipeline.cl, constants.WAN_DIT_OUT_CHANNELS), + dtype=np.int32, + ), } pipeline.transformer.qpc_session.set_buffers(output_buffer) transformer_perf = [] @@ -261,7 +261,7 @@ def wan_pipeline_call_with_mad_validation( # Prepare inputs for QAIC inference inputs_aic = { - "hidden_states": latents.detach().numpy(), + "hidden_states": latent_model_input.detach().numpy(), "encoder_hidden_states": encoder_hidden_states.detach().numpy(), "rotary_emb": rotary_emb.detach().numpy(), "temb": temb.detach().numpy(), @@ -273,7 +273,7 @@ def wan_pipeline_call_with_mad_validation( noise_pred_torch = pytorch_current_model( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=pytorch_prompt_embeds, + encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] @@ -330,7 +330,7 @@ def wan_pipeline_call_with_mad_validation( video_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0] # Allocate output buffer for VAE decoder - output_buffer = {"sample": np.random.rand(batch_size, 3, num_frames, height, width).astype(np.int32)} + output_buffer = {"sample": np.zeros((batch_size, 3, num_frames, height, width), dtype=np.int32)} pipeline.vae_decoder.qpc_session.set_buffers(output_buffer) # Run VAE decoder inference and measure time @@ -369,6 +369,9 @@ def wan_pipeline_call_with_mad_validation( @pytest.fixture(scope="session") def wan_pipeline(): """Build the WAN pipeline with random weights/ dummy config.""" + torch.manual_seed(TEST_SEED) + np.random.seed(TEST_SEED) + config = INITIAL_TEST_CONFIG["model_setup"] model_id = "Wan-AI/Wan2.2-T2V-A14B-Diffusers" pipe_cfg = WanPipeline.load_config(model_id) @@ -411,6 +414,10 @@ def wan_pipeline(): boundary_ratio=pipe_cfg.get("boundary_ratio"), expand_timesteps=pipe_cfg.get("expand_timesteps", False), ) + vae.eval() + transformer_high.eval() + transformer_low.eval() + text_encoder.eval() pytorch_pipeline_copy = copy.deepcopy(pytorch_pipeline) pipeline = QEffWanPipeline(pytorch_pipeline_copy) @@ -448,7 +455,7 @@ def test_wan_pipeline(wan_pipeline): num_frames = config["model_setup"]["num_frames"] # Generate with MAD validation - generator = torch.manual_seed(42) + generator = torch.Generator(device="cpu").manual_seed(TEST_SEED) start_time = time.time() try: diff --git a/tests/diffusers/test_wan_i2v.py b/tests/diffusers/test_wan_i2v.py index 4c86eb2c41..8f4e56272e 100644 --- a/tests/diffusers/test_wan_i2v.py +++ b/tests/diffusers/test_wan_i2v.py @@ -34,6 +34,7 @@ # Test Configuration for I2V with dynamic sizing CONFIG_PATH = "tests/diffusers/wan_i2v_test_config.json" 
INITIAL_TEST_CONFIG = load_json(CONFIG_PATH) +TEST_SEED = 42 def prepare_test_image_with_dynamic_sizing(pipeline, config): @@ -257,11 +258,10 @@ def wan_i2v_pipeline_call_with_mad_validation( ) output_buffer = { - "output": np.random.rand( - batch_size, - pipeline.cl, - constants.WAN_DIT_OUT_CHANNELS, - ).astype(np.int32), + "output": np.zeros( + (batch_size, pipeline.cl, constants.WAN_DIT_OUT_CHANNELS), + dtype=np.int32, + ), } pipeline.transformer.qpc_session.set_buffers(output_buffer) transformer_perf = [] @@ -408,7 +408,7 @@ def wan_i2v_pipeline_call_with_mad_validation( video_torch = pytorch_pipeline.vae.decode(latents, return_dict=False)[0] # Allocate output buffer for VAE decoder - output_buffer = {"sample": np.random.rand(batch_size, 3, num_frames, height, width).astype(np.int32)} + output_buffer = {"sample": np.zeros((batch_size, 3, num_frames, height, width), dtype=np.int32)} pipeline.vae_decoder.qpc_session.set_buffers(output_buffer) # Run VAE decoder inference and measure time @@ -448,6 +448,9 @@ def wan_i2v_pipeline_call_with_mad_validation( @pytest.fixture(scope="session") def wan_i2v_pipeline(): """Build the WAN I2V pipeline with random weights/dummy config.""" + torch.manual_seed(TEST_SEED) + np.random.seed(TEST_SEED) + config = INITIAL_TEST_CONFIG["model_setup"] model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" pipe_cfg = WanImageToVideoPipeline.load_config(model_id) @@ -490,6 +493,10 @@ def wan_i2v_pipeline(): boundary_ratio=pipe_cfg.get("boundary_ratio"), expand_timesteps=pipe_cfg.get("expand_timesteps", False), ) + vae.eval() + transformer_high.eval() + transformer_low.eval() + text_encoder.eval() import copy @@ -535,7 +542,7 @@ def test_wan_i2v_pipeline(wan_i2v_pipeline): num_frames = config["model_setup"]["num_frames"] # Generate with MAD validation - generator = torch.manual_seed(42) + generator = torch.Generator(device="cpu").manual_seed(TEST_SEED) start_time = time.time() try: diff --git a/tests/diffusers/wan_test_config.json b/tests/diffusers/wan_test_config.json index 0177980472..1145078dc2 100644 --- a/tests/diffusers/wan_test_config.json +++ b/tests/diffusers/wan_test_config.json @@ -10,7 +10,7 @@ "tolerances": { "transformer_high": 0.1, "transformer_low": 0.1, - "vae_decoder" : 1 + "vae_decoder" : 1.3 } }, "pipeline_params": { diff --git a/tests/transformers/models/test_speech_seq2seq_models.py b/tests/transformers/models/test_speech_seq2seq_models.py index 774802c83e..80ab7f2062 100644 --- a/tests/transformers/models/test_speech_seq2seq_models.py +++ b/tests/transformers/models/test_speech_seq2seq_models.py @@ -7,7 +7,6 @@ import json import os -from importlib import reload from typing import List, Optional import numpy as np @@ -15,7 +14,6 @@ import onnxruntime import pytest import torch -import transformers from datasets import load_dataset from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor @@ -89,7 +87,6 @@ def run_seq2seq_pytorch_hf( ) # TODO: temporary hack to nullify effect of KVCacheTransform add this as setup_module in pytest - reload(transformers.cache_utils) # encoder run outputs = model(**model_inputs) @@ -354,6 +351,7 @@ def check_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.on_qaic @pytest.mark.llm_model +# @pytest.mark.skip(reason="Whisper is failing with the latest transformers v4.57.3") @pytest.mark.parametrize("model_name", test_models) def test_seq2seq_pytorch_vs_kv_vs_ort_vs_ai100(model_name): """ diff --git a/tests/unit_test/models/test_model_quickcheck.py b/tests/unit_test/models/test_model_quickcheck.py index 
398fd69240..c0b5c20525 100644 --- a/tests/unit_test/models/test_model_quickcheck.py +++ b/tests/unit_test/models/test_model_quickcheck.py @@ -76,6 +76,14 @@ "olmo2": "hf-internal-testing/tiny-random-Olmo2ForCausalLM", "gpt_oss": "tiny-random/gpt-oss-bf16", } +CAUSAL_MULTI_SUBFUNCTION_MODEL_TYPES = { + "codegen", + "phi", + "starcoder2", + "mixtral", + "gpt_oss", + # "granitemoe" is intentionally not listed in CAUSAL_RUNTIME_MODEL_IDS yet. +} VLM_TEXT_RUNTIME_MODEL_ID = "tiny-random/gemma-3" VLM_EXPORT_MODEL_IDS = { @@ -170,6 +178,22 @@ def _exported_onnx_path(export_result) -> Path: return onnx_path +def _count_decoder_block_subfunctions(onnx_model, qeff_model) -> int: + get_submodules = getattr(qeff_model.model, "get_submodules_for_export", None) + if not callable(get_submodules): + return 0 + + submodules = get_submodules() + if not submodules: + return 0 + + if not isinstance(submodules, (set, list, tuple)): + submodules = [submodules] + + block_names = {module.__name__ for module in submodules if hasattr(module, "__name__")} + return sum(any(block_name in func.name for block_name in block_names) for func in onnx_model.functions) + + def _assert_has_retained_state_outputs(onnx_path: Path) -> None: onnx_model = onnx.load(onnx_path, load_external_data=False) retained_outputs = [output.name for output in onnx_model.graph.output if output.name.endswith("_RetainedState")] @@ -214,7 +238,7 @@ def _export_vlm_with_text_fallback(model_id: str, out_dir: Path) -> Path: try: config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) model_type = getattr(config, "model_type", "") - use_text_only_first = model_type in {"qwen2_5_vl", "internvl_chat"} + use_text_only_first = model_type in {"qwen2_5_vl", "qwen2_5_vl_text", "internvl_chat"} if not use_text_only_first: try: @@ -224,7 +248,7 @@ def _export_vlm_with_text_fallback(model_id: str, out_dir: Path) -> Path: pass try: - if model_type == "qwen2_5_vl" and getattr(config, "text_config", None) is not None: + if model_type in {"qwen2_5_vl", "qwen2_5_vl_text"} and getattr(config, "text_config", None) is not None: qwen2_cfg_dict = config.text_config.to_dict() qwen2_cfg_dict["model_type"] = "qwen2" qwen2_allowed_keys = set(Qwen2Config().to_dict().keys()) @@ -504,15 +528,71 @@ def test_causal_compile_with_subfunctions_all_models(model_type, model_id, tmp_p ids=sorted(CAUSAL_RUNTIME_MODEL_IDS), ) def test_causal_subfunction_export_smoke_all_models(model_type, model_id, tmp_path): - del model_type + if model_type == "gpt_oss": + pytest.skip("Subfunction runtime parity is currently excluded for gpt_oss in quickcheck.") + + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + if hasattr(tokenizer, "model_input_names"): + tokenizer.model_input_names = ["input_ids", "attention_mask"] + prompt = ["hello world"] + prompt_len = 8 + ctx_len = 12 + + model_hf = AutoModelForCausalLM.from_pretrained( + model_id, + **MODEL_KWARGS, + low_cpu_mem_usage=False, + trust_remote_code=True, + torch_dtype=torch.float32, + ) + model_hf.eval() + + api_runner = ApiRunner( + batch_size=1, + tokenizer=tokenizer, + config=model_hf.config, + prompt=prompt, + prompt_len=prompt_len, + ctx_len=ctx_len, + full_batch_size=None, + ) + + hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + qeff_model = QEFFAutoModelForCausalLM(model_hf) + kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "with-subfunctions-all", use_onnx_subfunctions=True)) + ort_tokens = 
api_runner.run_kv_model_on_ort(str(onnx_path)) + + assert np.array_equal(hf_tokens, kv_tokens.squeeze(0)) + assert np.array_equal(kv_tokens, ort_tokens) + + +@pytest.mark.llm_model +@pytest.mark.parametrize( + ("model_type", "model_id"), + sorted(CAUSAL_RUNTIME_MODEL_IDS.items()), + ids=sorted(CAUSAL_RUNTIME_MODEL_IDS), +) +def test_causal_subfunction_count_with_onnx_subfunctions(model_type, model_id, tmp_path): try: qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True) except Exception as exc: _skip_on_model_fetch_error(exc, model_id) - onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "with-subfunctions-all", use_onnx_subfunctions=True)) + onnx_path = _exported_onnx_path( + qeff_model.export(tmp_path / f"subfunction-count-{model_type}", use_onnx_subfunctions=True) + ) onnx_model = onnx.load(onnx_path, load_external_data=False) - assert len(onnx_model.functions) > 0 + subfunction_count = _count_decoder_block_subfunctions(onnx_model, qeff_model) + + if model_type in CAUSAL_MULTI_SUBFUNCTION_MODEL_TYPES: + assert subfunction_count > 1, ( + f"{model_type} expected multiple decoder-block subfunctions (>1), but found {subfunction_count}" + ) + else: + assert subfunction_count == 1, ( + f"{model_type} expected a single decoder-block subfunction (1), but found {subfunction_count}" + ) @pytest.mark.llm_model