From 4823c0c302d3658cfdc3bff50a6c9bb5d74830d5 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi <68580231+vbaddi@users.noreply.github.com> Date: Tue, 21 Apr 2026 23:29:14 -0400 Subject: [PATCH 1/2] feat(rope_fix): Hoist layer-invariant RoPE indexing out of decoder subfunctions for cached text models (#928) This change moves layer-invariant RoPE cos/sin indexing out of repeated decoder-layer subfunctions and into model-level forward paths. For cached decoder models, we were repeatedly doing: ``` cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) ``` inside each decoder attention block. With ONNX subfunctions enabled, that indexing becomes part of the exported repeated subfunction body and contributes to the on-device regression we observed after the single-subfunction Rope Fix work #880 . This patch hoists that work once per forward pass and passes the already-shaped cos/sin tensors into each decoder layer. Applied the refactor to the applicable QEff model families that thread static cached RoPE tensors through repeated decoder layers, including: - Llama - Llama SwiftKV - Gemma - Gemma2 - Mistral - Falcon - GPT-OSS - Granite - GraniteMoE - Mllama text path - Mixtral - Olmo2 - Phi3 - Qwen2 - Qwen3 - Qwen3 MoE - Qwen2.5 VL text path - Qwen3 VL text path - Qwen3 VL MoE text path For the Qwen VL text towers, the same idea is applied to the indexed/interleaved MRoPE preparation: the already-indexed cos/sin tensors are prepared once before the decoder-layer loop and reused across layers. Added a TinyLlama regression test to assert that export with subfunctions still produces a single decoder-layer ONNX function. Verified: `python -m pytest -q tests/unit_test/models/test_model_quickcheck.py -n auto` --------- Signed-off-by: vbaddi Signed-off-by: Rishin Raj Co-authored-by: Rishin Raj --- .../models/falcon/modeling_falcon.py | 23 +--- .../models/gemma/modeling_gemma.py | 26 +--- .../models/gemma2/modeling_gemma2.py | 26 +--- .../models/gpt_oss/modeling_gpt_oss.py | 46 +++---- .../models/granite/modeling_granite.py | 26 +--- .../models/granitemoe/modeling_granitemoe.py | 92 +++++++------- .../models/llama/modeling_llama.py | 65 ++++------ .../llama_swiftkv/modeling_llama_swiftkv.py | 29 +++-- .../models/mistral/modeling_mistral.py | 25 +--- .../models/mixtral_moe/modeling_mixtral.py | 25 +--- .../models/mllama/modeling_mllama.py | 25 +--- .../models/olmo2/modeling_olmo2.py | 27 +--- .../transformers/models/phi3/modeling_phi3.py | 25 +--- .../models/qwen2/modeling_qwen2.py | 27 +--- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 36 ++---- .../models/qwen3/modeling_qwen3.py | 26 +--- .../models/qwen3_moe/modeling_qwen3_moe.py | 25 +--- .../models/qwen3_vl/modeling_qwen3_vl.py | 109 ++++++++-------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 116 +++++++++--------- .../models/test_single_subfunction.py | 22 +++- 20 files changed, 318 insertions(+), 503 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 26080a59a..8f8d7e587 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -60,7 +60,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. 
Args: @@ -68,22 +68,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -127,7 +114,7 @@ def forward( value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached, position_ids) + query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached) if layer_past is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -301,6 +288,8 @@ def forward( all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for i, block in enumerate(self.h): if output_hidden_states: @@ -319,8 +308,8 @@ def forward( output_attentions=output_attentions, alibi=alibi, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) hidden_states = outputs[0] diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 0d740c717..3b326a98e 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -56,7 +56,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -64,22 +64,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -138,10 +125,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -295,6 +279,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -309,8 +295,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index ac6de7de4..b2c6c710d 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -59,7 +59,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -67,22 +67,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -145,10 +132,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -339,6 +323,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers: if output_hidden_states: @@ -354,8 +340,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index d0b928353..40e7ac23a 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -535,7 +535,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -552,25 +552,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -748,9 +733,7 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -832,9 +815,7 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -912,9 +893,8 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -1071,6 +1051,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers: if output_hidden_states: @@ -1086,8 +1068,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) hidden_states = layer_outputs[0] @@ -1112,8 +1094,8 @@ def forward( class QEffGptOssModel(GptOssModel): def __qeff_init__(self): self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * 
self.rotary_emb.attention_scaling) def forward( self, @@ -1171,6 +1153,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers: if output_hidden_states: @@ -1187,8 +1171,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 81aa19294..56b6532f1 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -54,7 +54,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -62,22 +62,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -131,10 +118,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -300,6 +284,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -315,8 +301,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 40359e7c8..3ff2161a2 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -65,8 +65,6 @@ def qeff_apply_rotary_pos_emb( k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - position_ids: torch.Tensor, - unsqueeze_dim: int = 1, ): """Applies Rotary Position Embedding to the query and key tensors. @@ -75,22 +73,9 @@ def qeff_apply_rotary_pos_emb( k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -126,34 +111,45 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin_cached, - "cos": cos_cached, - "cache_position": cache_position, - "batch_index": batch_index, - "position_ids": position_ids, - } - if comp_ctx_lengths is not None: - attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] - cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface = eager_attention_forward - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - scaling=self.scaling, - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) + past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) + if use_blocking: + attn_output, attn_weights = generic_blocked_attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + scaling=self.scaling, + layer_idx=self.layer_idx, + past_key_value=past_key_values, + blocking_config=blocking_config, + comp_ctx_length=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids, + past_seen_tokens=past_seen_tokens, + ) + else: + key_states, value_states, _ = past_key_value_update( + module=self, + key=key_states, + value=value_states, + attention_mask=attention_mask, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids, + ) + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) @@ -339,6 +335,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -355,8 +353,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) else: layer_outputs = decoder_layer( @@ -369,8 +367,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - sin_cached=self.sin_cached, - 
cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 00f97e24d..760def7af 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -55,7 +55,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -63,22 +63,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -216,30 +203,26 @@ def forward( value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) - - if past_key_values is not None: - if num_kv_blocks is not None: - cache_kwargs = { - "batch_index": batch_index, - "position_ids": position_ids, - "past_seen_tokens": past_seen_tokens, - } - past_key_values.write_only(key_states, value_states, self.layer_idx, cache_kwargs) - else: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} - if comp_ctx_lengths is not None: - attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] - cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_values.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - if num_kv_blocks is not None: - attention_interface = eager_attention_forward_blockedKV + past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) + if use_blocking: + attn_output, attn_weights = generic_blocked_attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + scaling=self.scaling, + layer_idx=self.layer_idx, + past_key_value=past_key_values, + blocking_config=blocking_config, + comp_ctx_length=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids, + past_seen_tokens=past_seen_tokens, + ) else: attention_interface = eager_attention_forward @@ -367,6 +350,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -381,8 +366,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 3667af854..65d1edc2a 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -98,6 +98,7 @@ def forward( query = self.q_proj_swiftkv(hidden_states) # Reshape the query, key, and value tensors. 
query_states = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + token_index = position_ids.to(torch.int32).argmax(1) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} if past_key_values is not None: @@ -113,10 +114,14 @@ def forward( # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) key_states, value_states = past_key_values.read_only(self.layer_idx, cache_kwargs=cache_kwargs) - position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) - query_states, _ = qeff_apply_rotary_pos_emb( - query_states, torch.empty_like(query_states), cos_cached, sin_cached, position_ids - ) + position_ids = position_ids[torch.arange(bsz), token_index].unsqueeze(1) + if cos_cached.dim() == 2: + cos = cos_cached[position_ids].unsqueeze(1) + sin = sin_cached[position_ids].unsqueeze(1) + else: + cos = cos_cached[torch.arange(bsz), :, token_index, :].unsqueeze(2) + sin = sin_cached[torch.arange(bsz), :, token_index, :].unsqueeze(2) + query_states, _ = qeff_apply_rotary_pos_emb(query_states, torch.empty_like(query_states), cos, sin) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -230,6 +235,8 @@ def _run_swiftkv_layers( causal_mask, batch_index, ) -> torch.Tensor: + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] hidden_states = layer( @@ -239,8 +246,8 @@ def _run_swiftkv_layers( comp_ctx_lengths, causal_mask, batch_index, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) hidden_states = self.norm(hidden_states) @@ -350,6 +357,8 @@ def forward( ) if position_ids is None: position_ids = cache_position.unsqueeze(0) + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) causal_mask = self._update_causal_mask( None, inputs_embeds, cache_position, position_ids, past_key_values, False @@ -369,8 +378,8 @@ def forward( batch_index=batch_index, output_attentions=False, use_cache=True, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) bsz, q_len, _ = hidden_states.size() @@ -396,9 +405,7 @@ def forward( ) # kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) - _, key_states = qeff_apply_rotary_pos_emb( - torch.empty_like(key_states), key_states, self.cos_cached, self.sin_cached, position_ids - ) + _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 14aee1cf4..a6ca8fca2 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -59,7 +59,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. 
Args: @@ -67,22 +67,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -149,9 +136,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -312,6 +297,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers: if output_hidden_states: @@ -327,8 +314,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 12c8ee99f..00e6a9f4a 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -61,7 +61,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -69,22 +69,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -146,9 +133,7 @@ def forward( ) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -378,6 +363,8 @@ def forward( # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -397,8 +384,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index a22e7960f..0e343a358 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -124,7 +124,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -132,22 +132,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -265,9 +252,7 @@ def forward( ) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = { @@ -646,6 +631,8 @@ def forward( # embed positions hidden_states = inputs_embeds + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): # For text-only path we should skip cross attention layers. @@ -677,8 +664,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index fe2ebee12..d4e82c4ea 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -55,7 +55,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -63,22 +63,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -137,11 +124,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - # kv_seq_len = key_states.shape[-2] - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -281,6 +264,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -295,8 +280,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index cf00205f4..1ea31a9d8 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -53,7 +53,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -61,23 +61,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -148,9 +135,7 @@ def forward( # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = { @@ -309,6 +294,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: @@ -323,8 +310,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index a76113fd0..adfba2960 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -59,7 +59,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -76,25 +76,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -152,9 +137,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -309,6 +292,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers: if output_hidden_states: @@ -323,8 +308,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 7944dff65..4b60ba044 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -42,7 +42,15 @@ from QEfficient.utils.logging_utils import logger -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): +def qeff_prepare_mrope_cos_sin(cos, sin, position_ids): + cos = cos[position_ids] + sin = sin[position_ids] + cos = torch.cat([cos[0, ..., 0:32], cos[1, ..., 32:80], cos[2, ..., 80:128]], dim=-1).unsqueeze(1) + sin = torch.cat([sin[0, ..., 0:32], sin[1, ..., 32:80], sin[2, ..., 80:128]], dim=-1).unsqueeze(1) + return cos, sin + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -59,27 +67,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids] - sin = sin[position_ids] - cos = torch.cat([cos[0, ..., 0:32], cos[1, ..., 32:80], cos[2, ..., 80:128]], dim=-1).unsqueeze(unsqueeze_dim) - sin = torch.cat([sin[0, ..., 0:32], sin[1, ..., 32:80], sin[2, ..., 80:128]], dim=-1).unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -585,9 +576,7 @@ def forward( # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids[1:], self.rope_scaling["mrope_section"] - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: if num_kv_blocks is not None: @@ -769,6 +758,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + cos, sin = qeff_prepare_mrope_cos_sin(self.cos_cached, self.sin_cached, position_ids[1:]) for decoder_layer in self.layers: if output_hidden_states: @@ -784,8 +774,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index d1069f225..06bb06aae 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -59,7 +59,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -76,24 +76,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -153,9 +139,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -310,6 +294,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers: if output_hidden_states: @@ -324,8 +310,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index f040e5ecf..ced55fcd9 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -51,7 +51,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -59,23 +59,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - # Apply rotation q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -199,9 +186,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos_cached, sin_cached, position_ids - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) if past_key_values is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -334,6 +319,8 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None + sin = self.sin_cached[position_ids].unsqueeze(1) + cos = self.cos_cached[position_ids].unsqueeze(1) for decoder_layer in self.layers: if output_hidden_states: @@ -348,8 +335,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index c914d48c3..4822119ba 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -62,7 +62,15 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section): return freqs_t -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): +def qeff_prepare_mrope_cos_sin(cos, sin, position_ids, mrope_section): + cos = cos[position_ids] + sin = sin[position_ids] + cos = qeff_apply_interleaved_mrope(cos, mrope_section).unsqueeze(1) + sin = qeff_apply_interleaved_mrope(sin, mrope_section).unsqueeze(1) + return cos, sin + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -79,27 +87,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids] - sin = sin[position_ids] - cos = qeff_apply_interleaved_mrope(cos, mrope_section) - sin = qeff_apply_interleaved_mrope(sin, mrope_section) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -390,41 +380,45 @@ def forward( past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, - key_states, - cos_cached, - sin_cached, - position_ids[1:], - self.config.rope_scaling["mrope_section"], - ) - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin_cached, - "cos": cos_cached, - "batch_index": batch_index, - "position_ids": position_ids[0], - "past_seen_tokens": past_seen_tokens, - } - if comp_ctx_lengths is not None: - attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] - cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - cache_kwargs=cache_kwargs, - layer_idx=self.layer_idx, - past_key_values=past_key_values, - **kwargs, - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) + if use_blocking: + attn_output, attn_weights = generic_blocked_attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + scaling=self.scaling, + layer_idx=self.layer_idx, + past_key_value=past_key_values, + blocking_config=blocking_config, + comp_ctx_length=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids[0], + past_seen_tokens=past_seen_tokens, + ) + else: + key_states, value_states, _ = past_key_value_update( + module=self, + key=key_states, + value=value_states, + attention_mask=attention_mask, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids[0], + ) + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) attn_output = attn_output.reshape(bsz, q_len, -1) @@ -562,6 +556,9 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) + cos, sin = qeff_prepare_mrope_cos_sin( + self.cos_cached, self.sin_cached, position_ids[1:], self.config.rope_scaling["mrope_section"] + ) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -582,8 +579,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, - sin_cached=self.sin_cached, - cos_cached=self.cos_cached, + sin_cached=sin, + cos_cached=cos, **kwargs, ) diff --git 
a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 83cf3d40e..3e2eb1cf9 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -58,7 +58,18 @@ def qeff_apply_interleaved_mrope(freqs, mrope_section): return freqs_t -def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1): +def qeff_prepare_mrope_cos_sin(cos, sin, position_ids, mrope_section): + invalid_pos_mask = position_ids < 0 + safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids) + flat_pos = safe_position_ids.reshape(-1) + cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1]) + sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1]) + cos = qeff_apply_interleaved_mrope(cos, mrope_section).unsqueeze(1) + sin = qeff_apply_interleaved_mrope(sin, mrope_section).unsqueeze(1) + return cos, sin + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). Explanation: @@ -75,31 +86,9 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqu k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - mrope_section(`List(int)`): - Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - # Safe gather: map padded -1 IDs to 0 for gather, then zero them out after interleave. 
- invalid_pos_mask = position_ids < 0 - safe_position_ids = torch.where(invalid_pos_mask, torch.zeros_like(position_ids), position_ids) - flat_pos = safe_position_ids.reshape(-1) - cos = cos.index_select(0, flat_pos).reshape(*safe_position_ids.shape, cos.shape[-1]) - sin = sin.index_select(0, flat_pos).reshape(*safe_position_ids.shape, sin.shape[-1]) - cos = qeff_apply_interleaved_mrope(cos, mrope_section) - sin = qeff_apply_interleaved_mrope(sin, mrope_section) - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -390,40 +379,46 @@ def forward( query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, - key_states, - cos_cached, - sin_cached, - position_ids[1:], - self.config.rope_scaling["mrope_section"], - ) - if past_key_values is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin_cached, - "cos": cos_cached, - "batch_index": batch_index, - "position_ids": position_ids[0], - } - if comp_ctx_lengths is not None: - attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] - cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - cache_kwargs=cache_kwargs, - layer_idx=self.layer_idx, - past_key_values=past_key_values, - **kwargs, - ) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) + past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 + blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) + use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) + if use_blocking: + attn_output, attn_weights = generic_blocked_attention_interface( + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + scaling=self.scaling, + layer_idx=self.layer_idx, + past_key_value=past_key_values, + blocking_config=blocking_config, + comp_ctx_length=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids[0], + past_seen_tokens=past_seen_tokens, + ) + else: + key_states, value_states, _ = past_key_value_update( + module=self, + key=key_states, + value=value_states, + attention_mask=attention_mask, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + position_ids=position_ids[0], + ) + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) attn_output = attn_output.reshape(bsz, q_len, -1) @@ -562,6 +557,9 @@ def forward( hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids[1:]) + cos, sin = qeff_prepare_mrope_cos_sin( + self.cos_cached, self.sin_cached, position_ids[1:], self.config.rope_scaling["mrope_section"] + ) # decoder layers 
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

@@ -582,8 +580,8 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
-                sin_cached=self.sin_cached,
-                cos_cached=self.cos_cached,
+                sin_cached=sin,
+                cos_cached=cos,
                 **kwargs,
             )

diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py
index f17edab65..91b7f5527 100644
--- a/tests/transformers/models/test_single_subfunction.py
+++ b/tests/transformers/models/test_single_subfunction.py
@@ -85,10 +85,28 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path):
     keywords = ["DecoderLayer", "Block", "Layer"]
     filtered = [name for name in functions_names if any(key in name for key in keywords)]

-    if len(filtered) > 1:
-        raise AssertionError(f"function definition, but found {len(functions_names)} functions: {functions_names}")
+    assert len(filtered) == 1, f"Expected a single decoder subfunction, found {len(filtered)}: {functions_names}"

     if not get_available_device_id():
         pytest.skip("No available devices to run model on Cloud AI 100")

     compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
     model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True)
+
+
+@pytest.mark.feature
+def test_tinyllama_exports_single_decoder_subfunction(tmp_path):
+    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+    try:
+        qeff_model = QEFFAutoModelForCausalLM(
+            AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs),
+            cb=False,
+        )
+    except Exception as exc:
+        pytest.skip(f"Skipping {model_id}: unable to load model in this environment ({type(exc).__name__}: {exc})")
+
+    with_sub_func_onnx = qeff_model.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False)
+    functions_names = get_function(with_sub_func_onnx)
+    filtered = [name for name in functions_names if "DecoderLayer" in name]
+
+    assert len(filtered) == 1, f"Expected a single decoder subfunction, found {len(filtered)}: {functions_names}"

From 835b4bd64abe6a626063e6a6d58c86982cea5164 Mon Sep 17 00:00:00 2001
From: Rishin Raj
Date: Wed, 22 Apr 2026 12:27:25 +0530
Subject: [PATCH 2/2] rebase with release branch

Signed-off-by: Rishin Raj
---
 .../models/gpt_oss/modeling_gpt_oss.py | 1 -
 .../models/granitemoe/modeling_granitemoe.py | 60 +++++++----------
 .../models/llama/modeling_llama.py | 40 +++++------
 .../models/qwen3_vl/modeling_qwen3_vl.py | 66 ++++++++-----------
 .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 66 ++++++++-----------
 5 files changed, 96 insertions(+), 137 deletions(-)

diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
index 40e7ac23a..6a64a4ccb 100644
--- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -893,7 +893,6 @@ def forward(
         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-        past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0

         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached)
         if past_key_values is not None:
diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py 
b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 3ff2161a2..e1fcb0291 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -112,44 +112,28 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) - past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) - if use_blocking: - attn_output, attn_weights = generic_blocked_attention_interface( - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - scaling=self.scaling, - layer_idx=self.layer_idx, - past_key_value=past_key_values, - blocking_config=blocking_config, - comp_ctx_length=comp_ctx_lengths, - batch_index=batch_index, - position_ids=position_ids, - past_seen_tokens=past_seen_tokens, - ) - else: - key_states, value_states, _ = past_key_value_update( - module=self, - key=key_states, - value=value_states, - attention_mask=attention_mask, - past_key_value=past_key_values, - comp_ctx_lengths=comp_ctx_lengths, - batch_index=batch_index, - position_ids=position_ids, - ) - attn_output, attn_weights = eager_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - scaling=self.scaling, - ) + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin_cached, + "cos": cos_cached, + "cache_position": cache_position, + "batch_index": batch_index, + "position_ids": position_ids, + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + ) attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 760def7af..b6cdd2cd5 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -203,26 +203,28 @@ def forward( value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) - if use_blocking: - attn_output, attn_weights = generic_blocked_attention_interface( - module=self, - query=query_states, - key=key_states, - 
value=value_states, - attention_mask=attention_mask, - scaling=self.scaling, - layer_idx=self.layer_idx, - past_key_value=past_key_values, - blocking_config=blocking_config, - comp_ctx_length=comp_ctx_lengths, - batch_index=batch_index, - position_ids=position_ids, - past_seen_tokens=past_seen_tokens, - ) + cache_kwargs = None + if past_key_values is not None: + if num_kv_blocks is not None: + cache_kwargs = { + "batch_index": batch_index, + "position_ids": position_ids, + "past_seen_tokens": past_seen_tokens, + } + past_key_values.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + else: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if num_kv_blocks is not None: + attention_interface = eager_attention_forward_blockedKV else: attention_interface = eager_attention_forward diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 4822119ba..b5140cd50 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -381,44 +381,32 @@ def forward( past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) - if use_blocking: - attn_output, attn_weights = generic_blocked_attention_interface( - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - scaling=self.scaling, - layer_idx=self.layer_idx, - past_key_value=past_key_values, - blocking_config=blocking_config, - comp_ctx_length=comp_ctx_lengths, - batch_index=batch_index, - position_ids=position_ids[0], - past_seen_tokens=past_seen_tokens, - ) - else: - key_states, value_states, _ = past_key_value_update( - module=self, - key=key_states, - value=value_states, - attention_mask=attention_mask, - past_key_value=past_key_values, - comp_ctx_lengths=comp_ctx_lengths, - batch_index=batch_index, - position_ids=position_ids[0], - ) - attn_output, attn_weights = eager_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - scaling=self.scaling, - **kwargs, - ) + cache_kwargs = None + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin_cached, + "cos": cos_cached, + "batch_index": batch_index, + "position_ids": position_ids[0], + "past_seen_tokens": past_seen_tokens, + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_values.update(key_states, 
value_states, self.layer_idx, cache_kwargs) + + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_value=past_key_values, + **kwargs, + ) attn_output = attn_output.reshape(bsz, q_len, -1) diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 3e2eb1cf9..1c5cafe9b 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -380,45 +380,31 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached) - past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0 - blocking_config = getattr(self, "attn_blocking_config", AttentionBlockingConfig()) - use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE) - if use_blocking: - attn_output, attn_weights = generic_blocked_attention_interface( - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - scaling=self.scaling, - layer_idx=self.layer_idx, - past_key_value=past_key_values, - blocking_config=blocking_config, - comp_ctx_length=comp_ctx_lengths, - batch_index=batch_index, - position_ids=position_ids[0], - past_seen_tokens=past_seen_tokens, - ) - else: - key_states, value_states, _ = past_key_value_update( - module=self, - key=key_states, - value=value_states, - attention_mask=attention_mask, - past_key_value=past_key_values, - comp_ctx_lengths=comp_ctx_lengths, - batch_index=batch_index, - position_ids=position_ids[0], - ) - attn_output, attn_weights = eager_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - scaling=self.scaling, - **kwargs, - ) + cache_kwargs = None + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin_cached, + "cos": cos_cached, + "batch_index": batch_index, + "position_ids": position_ids[0], + } + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attn_output, attn_weights = eager_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_value=past_key_values, + **kwargs, + ) attn_output = attn_output.reshape(bsz, q_len, -1)