From 921228385df283c127ca917397498060cfc92f98 Mon Sep 17 00:00:00 2001
From: abhishek-singh591
Date: Tue, 24 Mar 2026 13:08:28 +0000
Subject: [PATCH 1/9] Added all the changes of rope_fix

Signed-off-by: abhishek-singh591
---
 .../models/falcon/modeling_falcon.py | 35 ++++----
 .../models/gemma/modeling_gemma.py | 35 ++++----
 .../models/gemma2/modeling_gemma2.py | 35 ++++----
 .../models/gpt_oss/modeling_gpt_oss.py | 73 +++++++------
 .../models/granite/modeling_granite.py | 35 ++++----
 .../models/granitemoe/modeling_granitemoe.py | 37 ++++----
 .../models/llama/modeling_llama.py | 35 ++++----
 .../llama_swiftkv/modeling_llama_swiftkv.py | 24 ++++-
 .../models/mistral/modeling_mistral.py | 35 ++++----
 .../models/mixtral_moe/modeling_mixtral.py | 35 ++++----
 .../models/mllama/modeling_mllama.py | 37 ++++----
 .../models/olmo2/modeling_olmo2.py | 37 ++++----
 .../transformers/models/phi3/modeling_phi3.py | 35 ++++----
 .../models/qwen2/modeling_qwen2.py | 35 ++++----
 .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 37 ++++----
 .../models/qwen3/modeling_qwen3.py | 35 ++++----
 .../models/qwen3_moe/modeling_qwen3_moe.py | 35 ++++----
 .../causallm/example_pytorch_transforms.py | 12 +--
 .../models/test_single_subfunction.py | 90 +++++++++++++++++++
 19 files changed, 467 insertions(+), 265 deletions(-)
 create mode 100644 tests/transformers/models/test_single_subfunction.py

diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py
index 4ebb2fb96e..41f2206609 100644
--- a/QEfficient/transformers/models/falcon/modeling_falcon.py
+++ b/QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -59,16 +59,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-            self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-        )
-
 
 def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
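# The same refactor repeats in every modeling file of this patch: the per-layer
# rotary_emb forward() deleted above is no longer used; instead the *model*
# precomputes full-length cos/sin tables once in __qeff_init__ and threads them
# through every decoder layer into attention as explicit cos_cached / sin_cached
# arguments. A minimal sketch of that shape, with made-up class names
# (RotaryCache, TinyDecoderModel) rather than the actual QEfficient classes:

import torch


class RotaryCache(torch.nn.Module):
    """Builds full-length cos/sin tables once, up to max_position_embeddings."""

    def __init__(self, head_dim, max_pos, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)  # [max_pos, head_dim]
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)


class TinyDecoderModel(torch.nn.Module):
    def __init__(self, layers, head_dim, max_position_embeddings):
        super().__init__()
        self.layers = torch.nn.ModuleList(layers)
        rope = RotaryCache(head_dim, max_position_embeddings)
        # One table pair owned by the model and shared by all layers, instead of
        # a rotary_emb submodule living inside every attention block.
        self.cos_cached = torch.nn.Parameter(rope.cos_cached)
        self.sin_cached = torch.nn.Parameter(rope.sin_cached)

    def forward(self, hidden_states, position_ids):
        for layer in self.layers:
            # Each layer forwards the tables down to its attention module.
            hidden_states = layer(
                hidden_states,
                position_ids=position_ids,
                cos_cached=self.cos_cached,
                sin_cached=self.sin_cached,
            )
        return hidden_states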
@@ -108,9 +98,6 @@ class QEffFalconAttention(FalconAttention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -125,6 +112,8 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -137,8 +126,10 @@ def forward( key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_layer.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_layer.dtype) + cos, sin = cos_cached, sin_cached query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) if layer_past is not None: @@ -184,6 +175,8 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ): residual = hidden_states @@ -208,6 +201,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, ) if not self.config.new_decoder_architecture: @@ -245,6 +240,14 @@ class QEffFalconModel(FalconModel): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: torch.LongTensor = None, @@ -322,6 +325,8 @@ def forward( output_attentions=output_attentions, alibi=alibi, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = outputs[0] diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 260d1857a7..f13f163336 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -55,16 +55,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, 
unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -128,9 +118,6 @@ class QEffGemmaAttention(GemmaAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -140,6 +127,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -149,8 +138,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -194,6 +185,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -223,6 +216,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -243,6 +238,14 @@ class QEffGemmaModel(GemmaModel): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -310,6 +313,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 6dee8c85dd..3240e9089e 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, 
position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -135,9 +125,6 @@ class QEffGemma2Attention(Gemma2Attention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -147,6 +134,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -156,8 +145,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -208,6 +199,8 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -241,6 +234,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -271,6 +266,14 @@ class QEffGemma2Model(Gemma2Model): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -355,6 +358,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index e8f5fa89b3..dd53ab60eb 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -527,16 +527,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - 
self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -737,9 +727,6 @@ def opt_eager_attention_forward_blocked( class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -751,6 +738,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -759,9 +748,11 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + # if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + # max_seq_len_cached = 32 * 1024 + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -823,9 +814,6 @@ def forward( class QEffPrefillOnlyGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -837,6 +825,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -845,9 +835,11 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + # if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + # max_seq_len_cached = 32 * 1024 + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -905,9 +897,6 @@ def forward( class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -919,6 +908,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: 
Optional[torch.LongTensor] = None, sliding_mask=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -927,9 +918,11 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + # if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + # max_seq_len_cached = 32 * 1024 + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -986,6 +979,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC sliding_mask=None, + sin_cached=None, + cos_cached=None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states @@ -1002,6 +997,8 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, sliding_mask=sliding_mask, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -1024,6 +1021,14 @@ def forward( class QEffPrefillOnlyGptOssModel(GptOssModel): + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1093,6 +1098,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) hidden_states = layer_outputs[0] @@ -1115,6 +1122,14 @@ def forward( class QEffGptOssModel(GptOssModel): + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1187,6 +1202,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 8a32c52ef2..d527e63ce2 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -53,16 +53,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): 
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -121,9 +111,6 @@ def eager_attention_forward( class QEffGraniteAttention(GraniteAttention): - def __qeff_init__(self): - self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -133,6 +120,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -142,8 +131,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -192,6 +183,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -230,6 +223,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -249,6 +244,14 @@ def forward( class QEffGraniteModel(GraniteModel): + def __qeff_init__(self): + self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -316,6 +319,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 935df7c2d9..8ca4c10b3c 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ 
b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -59,16 +59,6 @@ def _set_cos_sin_cache(self, seq_len: int, device=None, dtype=None): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x: torch.Tensor, seq_len: int = None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb( q: torch.Tensor, @@ -111,9 +101,6 @@ def qeff_apply_rotary_pos_emb( class QEffGraniteMoeAttention(GraniteMoeAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -126,6 +113,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -137,8 +126,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -214,6 +205,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -255,6 +248,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) @@ -287,6 +282,14 @@ class QEffGraniteMoeModel(GraniteMoeModel): """ + def __qeff_init__(self): + self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -356,6 +359,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) else: layer_outputs = decoder_layer( @@ -368,6 +373,8 @@ def forward( 
use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 57bccdb1bb..e9bfd89bfc 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -54,16 +54,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -198,9 +188,6 @@ def eager_attention_forward_blockedKV( class QEffLlamaAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -212,6 +199,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, num_kv_blocks: Optional[torch.Tensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -226,9 +215,11 @@ def forward( key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -287,6 +278,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -303,6 +296,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -321,6 +316,14 @@ class QEffLlamaModel(LlamaModel): Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py """ + def __qeff_init__(self): + self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + 
self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -380,6 +383,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index e219d5e03a..9ebf7e4371 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -82,8 +82,6 @@ def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: ) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = QEffLlamaRotaryEmbedding(config=config) - def forward( self, hidden_states: torch.Tensor, @@ -92,6 +90,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: torch.Tensor = None, batch_index: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() q_len = 1 # as we always run this for single token @@ -110,10 +110,12 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( query_states, torch.empty_like(query_states), cos, sin, position_ids @@ -162,6 +164,8 @@ def forward( comp_ctx_lengths, causal_mask, batch_index: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states @@ -174,6 +178,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, + sin_cached=sin_cached, + cos_cached=cos_cached, ) hidden_states = residual + hidden_states @@ -206,6 +212,14 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def __qeff_init__(self): + self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def _run_swiftkv_layers( self, hidden_states: torch.Tensor, @@ -347,6 +361,8 @@ def forward( batch_index=batch_index, output_attentions=False, use_cache=True, + 
sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) bsz, q_len, _ = hidden_states.size() diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 47107384ed..48f71a89b1 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -131,9 +121,6 @@ class QEffMistralAttention(MistralAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -146,6 +133,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] = None, # kept here for BC + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -159,8 +148,10 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -205,6 +196,8 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -236,6 +229,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -256,6 +251,14 @@ class QEffMistralModel(MistralModel): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -328,6 +331,8 @@ def forward( output_attentions=output_attentions, 
use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 680c839ae5..c5a0760787 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -60,16 +60,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -128,9 +118,6 @@ def eager_attention_forward( class QEffMixtralAttention(MixtralAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -139,6 +126,8 @@ def forward( past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -155,8 +144,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -265,6 +256,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -301,6 +294,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -323,6 +318,14 @@ class QEffMixtralModel(MixtralModel): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + # Ignore copy def forward( self, @@ -397,6 +400,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 3cba022b48..f91cf04357 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -123,16 +123,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -241,9 +231,6 @@ class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -255,6 +242,8 @@ def forward( position_embeddings: torch.Tensor = None, use_cache: bool = False, cache_position=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -274,9 +263,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -326,6 +317,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -361,6 +354,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, ) hidden_states = residual + hidden_states @@ -465,6 +460,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -602,6 +599,14 @@ class QEffMllamaTextModel(MllamaTextModel): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -676,6 +681,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index c79ad7faee..1bda9df300 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -54,16 +54,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
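# In each attention forward in this patch, the full cos_cached / sin_cached
# tables are handed straight to qeff_apply_rotary_pos_emb together with
# position_ids, so the per-position selection happens inside that helper. A
# rough sketch of the usual gather-by-position_ids formulation (an assumption
# for illustration, not the verbatim QEfficient implementation):

import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_from_cache(q, k, cos_cached, sin_cached, position_ids, unsqueeze_dim=1):
    # cos_cached / sin_cached: [max_pos, head_dim]; position_ids: [batch, seq_len]
    cos = cos_cached[position_ids].unsqueeze(unsqueeze_dim)  # [batch, 1, seq_len, head_dim]
    sin = sin_cached[position_ids].unsqueeze(unsqueeze_dim)
    q_rot = (q * cos) + (rotate_half(q) * sin)  # q: [batch, heads, seq_len, head_dim]
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot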
@@ -123,9 +113,6 @@ def eager_attention_forward( class QEffOlmo2Attention(Olmo2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -135,6 +122,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -148,10 +137,12 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + # kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -198,6 +189,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -213,6 +206,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -233,6 +228,14 @@ class QEffOlmo2Model(Olmo2Model): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -297,6 +300,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index b48ab28979..585c0480d5 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -52,16 +52,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), 
- self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -129,9 +119,6 @@ class QEffPhi3Attention(Phi3Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -142,6 +129,8 @@ def forward( past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -157,8 +146,10 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -207,6 +198,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -244,6 +237,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) @@ -265,6 +260,14 @@ class QEffPhi3Model(Phi3Model): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -324,6 +327,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 841df65269..6e2ec30c3f 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - 
self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). @@ -141,9 +131,6 @@ class QEffQwen2Attention(Qwen2Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -153,6 +140,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -162,8 +151,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -208,6 +199,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -240,6 +233,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -261,6 +256,14 @@ class QEffQwen2Model(Qwen2Model): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -324,6 +327,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index d6bfbda81b..8a5ddca9c7 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -348,16 +348,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if 
seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def eager_attention_forward_blockedKV( module: nn.Module, @@ -564,9 +554,6 @@ class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention): and "Generating Long Sequences with Sparse Transformers". """ - def __qeff_init__(self): - self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -579,6 +566,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, num_kv_blocks: Optional[torch.Tensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -591,11 +580,13 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = key_states.shape[-2] + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids[1:], self.rope_scaling["mrope_section"] @@ -661,6 +652,8 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -701,6 +694,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -723,6 +718,14 @@ def forward( class QEffQwen2_5_VLTextModel(Qwen2_5_VLTextModel): + def __qeff_init__(self): + self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -785,6 +788,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index ccc4bbac29..da9fde5589 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, 
dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). @@ -142,9 +132,6 @@ class QEffQwen3Attention(Qwen3Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -154,6 +141,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -163,8 +152,10 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -209,6 +200,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -241,6 +234,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -262,6 +257,14 @@ class QEffQwen3Model(Qwen3Model): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -325,6 +328,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 6bdd5e2439..7dabca1971 100644 --- 
a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -50,16 +50,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -188,9 +178,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens class QEffQwen3MoeAttention(Qwen3MoeAttention): - def __qeff_init__(self): - self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -200,6 +187,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -209,8 +198,10 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -247,6 +238,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -279,6 +272,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, ) hidden_states = residual + hidden_states @@ -296,6 +291,14 @@ def forward( class QEffQwen3MoeModel(Qwen3MoeModel): + def __qeff_init__(self): + self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config) + self.rotary_emb._set_cos_sin_cache( + seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype + ) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -349,6 +352,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py 
b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index ff62588f9c..503efc12dc 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,12 +27,6 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from torch import nn # Example imports for three representative models @@ -62,6 +56,12 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py new file mode 100644 index 0000000000..968800918d --- /dev/null +++ b/tests/transformers/models/test_single_subfunction.py @@ -0,0 +1,90 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import onnx +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.utils.device_utils import get_available_device_id + +torch.manual_seed(42) + +configs = [ + ("gpt2", 256, 2, 4, 128, 512, 127, {}), + # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("falcon", 256, 2, 4, 128, 512, 127, {}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), + # ("phi", 256, 2, 4, 128, 512, 127, {}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for 
( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] + +model_kwargs = {"attn_implementation": "eager"} +config_ids = [x.model_type for x in configs] + + +def get_function(onnx_path): + """Check if ONNX model contains QEffGPT2Block function definition.""" + model = onnx.load(onnx_path, load_external_data=False) + function_names = [f.name for f in model.functions] + return function_names + + +@pytest.mark.on_qaic +@pytest.mark.feature +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_subfunction_vs_nonsubfunction(config, tmp_path): + model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False) + with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False) + + functions_names = get_function(with_sub_func_onnx) + print(functions_names) + if len(functions_names) != 12: + raise AssertionError(f"function definition, but found {len(functions_names)} functions: {functions_names}") + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + compile_params = {"prefill_seq_len": 8, "ctx_len": 16} + model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True) From 8acdd179c76ae0c43eefb2fad797c238db13182a Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Tue, 24 Mar 2026 13:13:32 +0000 Subject: [PATCH 2/9] lint Signed-off-by: abhishek-singh591 --- .../causallm/example_pytorch_transforms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index 503efc12dc..ff62588f9c 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,6 +27,12 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from torch import nn # Example imports for three representative models @@ -56,12 +62,6 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, From c225c0deca45c8c4b057e8e25528e28048b0158e Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 25 Mar 2026 03:56:51 +0000 Subject: [PATCH 3/9] made minnor fix Signed-off-by: abhishek-singh591 --- .../transformers/models/falcon/modeling_falcon.py | 6 +++--- .../transformers/models/gemma/modeling_gemma.py | 6 +++--- .../transformers/models/gemma2/modeling_gemma2.py | 6 +++--- .../transformers/models/gpt_oss/modeling_gpt_oss.py | 12 ++++++------ .../transformers/models/granite/modeling_granite.py | 10 +++++----- .../models/granitemoe/modeling_granitemoe.py | 10 +++++----- 
.../transformers/models/llama/modeling_llama.py | 10 +++++----- .../models/llama_swiftkv/modeling_llama_swiftkv.py | 12 ++++++------ .../transformers/models/mistral/modeling_mistral.py | 6 +++--- .../models/mixtral_moe/modeling_mixtral.py | 11 ++++++----- .../transformers/models/mllama/modeling_mllama.py | 10 +++++----- .../transformers/models/olmo2/modeling_olmo2.py | 12 ++++++------ QEfficient/transformers/models/phi3/modeling_phi3.py | 6 +++--- .../transformers/models/qwen2/modeling_qwen2.py | 10 +++++----- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 12 ++++++------ .../transformers/models/qwen3/modeling_qwen3.py | 10 +++++----- .../models/qwen3_moe/modeling_qwen3_moe.py | 6 +++--- tests/transformers/models/test_single_subfunction.py | 6 +++++- 18 files changed, 83 insertions(+), 78 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 41f2206609..9d30d6ffbb 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -126,9 +126,9 @@ def forward( key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_layer.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_layer.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_layer.dtype) + sin_cached[:kv_seq_len].to(dtype=value_layer.dtype) cos, sin = cos_cached, sin_cached query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index f13f163336..0ac965cb5e 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -138,9 +138,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 3240e9089e..9500a7253f 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -145,9 +145,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: 
hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index dd53ab60eb..7884049614 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -748,10 +748,10 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - # max_seq_len_cached = 32 * 1024 - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos_cached[:max_seq_len_cached].to(dtype=value_states.dtype) + sin_cached[:max_seq_len_cached].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -1026,8 +1026,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index d527e63ce2..b1724d92e6 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -131,9 +131,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -249,8 +249,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * 
self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 8ca4c10b3c..d13504f8d4 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -126,9 +126,9 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -287,8 +287,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index e9bfd89bfc..d4c4903000 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -215,10 +215,10 @@ def forward( key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -321,8 +321,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 9ebf7e4371..3a04e98004 100644 --- 
a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -110,11 +110,11 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( @@ -217,8 +217,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def _run_swiftkv_layers( self, @@ -388,7 +388,7 @@ def forward( ) kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) - cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.cos_cached[:kv_seq_len], self.sin_cached[:kv_seq_len] _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 48f71a89b1..f254aa0f0a 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -148,9 +148,9 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index c5a0760787..a27f6d2462 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -144,9 +144,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -323,8 +324,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) # Ignore copy def forward( diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index f91cf04357..94de5faa0d 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -263,10 +263,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -604,8 +604,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 1bda9df300..d89ce47383 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -137,11 +137,11 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - # kv_seq_len = key_states.shape[-2] + kv_seq_len = key_states.shape[-2] - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = 
qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -233,8 +233,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 585c0480d5..4fa05fa8bc 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -146,9 +146,9 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 6e2ec30c3f..85575196c1 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -151,9 +151,9 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -261,8 +261,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 8a5ddca9c7..9f7cbde56e 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -580,12 +580,12 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = 
value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - # kv_seq_len = key_states.shape[-2] - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + kv_seq_len = key_states.shape[-2] + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb( @@ -723,8 +723,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index da9fde5589..bcaa75dd4c 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -152,9 +152,9 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -262,8 +262,8 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) def forward( self, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 7dabca1971..3a1f5dfa3d 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -198,9 +198,9 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) + kv_seq_len = 
past_key_value.get_seq_length(self.layer_idx, cache_position) + cos_cached[:kv_seq_len].to(dtype=value_states.dtype) + sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 968800918d..f17edab654 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -81,7 +81,11 @@ def test_subfunction_vs_nonsubfunction(config, tmp_path): functions_names = get_function(with_sub_func_onnx) print(functions_names) - if len(functions_names) != 12: + + keywords = ["DecoderLayer", "Block", "Layer"] + filtered = [name for name in functions_names if any(key in name for key in keywords)] + + if len(filtered) > 1: raise AssertionError(f"function definition, but found {len(functions_names)} functions: {functions_names}") if not get_available_device_id(): From 3f828ceee6660be746d552356892aec35d1b5764 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 25 Mar 2026 05:17:20 +0000 Subject: [PATCH 4/9] Added few changes Signed-off-by: abhishek-singh591 --- .../models/falcon/modeling_falcon.py | 6 +++ .../models/gemma/modeling_gemma.py | 7 ++++ .../models/gemma2/modeling_gemma2.py | 7 ++++ .../models/gpt_oss/modeling_gpt_oss.py | 8 ++++ .../models/granite/modeling_granite.py | 7 ++++ .../models/granitemoe/modeling_granitemoe.py | 7 ++++ .../models/llama/modeling_llama.py | 7 ++++ .../llama_swiftkv/modeling_llama_swiftkv.py | 9 +++++ .../models/mistral/modeling_mistral.py | 7 ++++ .../models/mixtral_moe/modeling_mixtral.py | 7 ++++ .../models/mllama/modeling_mllama.py | 7 ++++ .../models/olmo2/modeling_olmo2.py | 7 ++++ .../transformers/models/phi3/modeling_phi3.py | 7 ++++ .../models/qwen2/modeling_qwen2.py | 7 ++++ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 7 ++++ .../models/qwen3/modeling_qwen3.py | 7 ++++ .../models/qwen3_moe/modeling_qwen3_moe.py | 7 ++++ .../models/test_single_subfunction.py | 38 +++++++++---------- 18 files changed, 140 insertions(+), 19 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 9d30d6ffbb..4278e7f87d 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -42,6 +42,8 @@ class QEffFalconRotaryEmbedding(FalconRotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: FalconConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. 
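Taken together, the hunks above implement one recurring pattern: the rotary cos/sin tables are built once when the model wrapper is initialized, exposed as cos_cached/sin_cached parameters on the QEff*Model class, threaded through every decoder layer into the attention forward, and sliced to the current KV length there instead of calling a per-layer rotary forward. The standalone sketch below only illustrates that flow; every name in it (ToyRotaryEmbedding, ToyAttention, ToyModel, max_pos) is made up for illustration and is not part of the QEfficient API, and the actual rotary application, attention math, and attention_scaling factor are omitted.

import torch


class ToyRotaryEmbedding(torch.nn.Module):
    def __init__(self, head_dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


class ToyAttention(torch.nn.Module):
    # The attention module no longer owns a rotary embedding; it receives the
    # precomputed tables from the model and slices them to the KV length.
    def forward(self, hidden_states, position_ids, kv_seq_len, cos_cached, sin_cached):
        cos = cos_cached[:kv_seq_len].to(dtype=hidden_states.dtype)
        sin = sin_cached[:kv_seq_len].to(dtype=hidden_states.dtype)
        # ... apply rotary position embedding to q/k using (cos, sin, position_ids) ...
        return hidden_states, (cos, sin)


class ToyModel(torch.nn.Module):
    def __init__(self, head_dim=64, max_pos=256):
        super().__init__()
        self.attn = ToyAttention()
        self.rotary_emb = ToyRotaryEmbedding(head_dim)
        self.rotary_emb._set_cos_sin_cache(max_pos, device="cpu", dtype=torch.float32)
        # Wrapped as parameters, mirroring the torch.nn.Parameter(...) lines
        # added to the QEff*Model __qeff_init__ methods in the diffs above.
        self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached)
        self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached)

    def forward(self, hidden_states, position_ids, kv_seq_len):
        # Thread the cached tables down to attention, as the decoder-layer
        # forwards in the diffs do via sin_cached=/cos_cached= keyword args.
        return self.attn(
            hidden_states, position_ids, kv_seq_len,
            cos_cached=self.cos_cached, sin_cached=self.sin_cached,
        )

As far as these diffs show, the design choice is to make the tables ordinary module state passed explicitly through the call chain, rather than recomputing them inside each attention module.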
@@ -127,6 +129,9 @@ def forward( value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + + if kv_seq_len > QEffFalconRotaryEmbedding._max_seq_len_cached: + self._set_cos_sin_cache(seq_len=kv_seq_len, device=value_layer.device, dtype=value_layer.dtype) cos_cached[:kv_seq_len].to(dtype=value_layer.dtype) sin_cached[:kv_seq_len].to(dtype=value_layer.dtype) cos, sin = cos_cached, sin_cached @@ -245,6 +250,7 @@ def __qeff_init__(self): self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) + QEffFalconRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 0ac965cb5e..badc6ee3a0 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -37,6 +37,8 @@ class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: GemmaConfig, device=None): super().__init__(config=config) @@ -139,6 +141,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffGemmaRotaryEmbedding._max_seq_len_cached: + QEffGemmaRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -240,6 +246,7 @@ class QEffGemmaModel(GemmaModel): def __qeff_init__(self): self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) + QEffGemmaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 9500a7253f..2a779c4d40 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -40,6 +40,8 @@ class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding): - Add static sin/cos computations. 
""" + _max_seq_len_cached = 0 + def __init__(self, config: Gemma2Config, device=None): super().__init__(config=config) @@ -146,6 +148,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffGemma2RotaryEmbedding._max_seq_len_cached: + QEffGemma2RotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -268,6 +274,7 @@ class QEffGemma2Model(Gemma2Model): def __qeff_init__(self): self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) + QEffGemma2RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 7884049614..5b150ae55a 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -510,6 +510,8 @@ class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: GptOssConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -750,6 +752,11 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): max_seq_len_cached = 32 * 1024 + + if max_seq_len_cached > QEffGptOssRotaryEmbedding._max_seq_len_cached: + QEffGptOssRotaryEmbedding._set_cos_sin_cache( + seq_len=max_seq_len_cached, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:max_seq_len_cached].to(dtype=value_states.dtype) sin_cached[:max_seq_len_cached].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -1124,6 +1131,7 @@ def forward( class QEffGptOssModel(GptOssModel): def __qeff_init__(self): self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + QEffGptOssRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index b1724d92e6..a565af999f 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -37,6 +37,8 @@ class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding): - Add static sin/cos computations. 
""" + _max_seq_len_cached = 0 + def __init__(self, config: GraniteConfig, device=None): super().__init__(config=config) self._set_cos_sin_cache( @@ -132,6 +134,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffGraniteRotaryEmbedding._max_seq_len_cached: + QEffGraniteRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -246,6 +252,7 @@ def forward( class QEffGraniteModel(GraniteModel): def __qeff_init__(self): self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) + QEffGraniteRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index d13504f8d4..2bccd1250a 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -38,6 +38,8 @@ class QEffGraniteMoeRotaryEmbedding(GraniteMoeRotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__( self, config: Optional[GraniteMoeConfig] = None, @@ -127,6 +129,10 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffGraniteMoeRotaryEmbedding._max_seq_len_cached: + QEffGraniteMoeRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -284,6 +290,7 @@ class QEffGraniteMoeModel(GraniteMoeModel): def __qeff_init__(self): self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) + QEffGraniteMoeRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index d4c4903000..b44473af84 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -37,6 +37,8 @@ class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): - Add static sin/cos computations. 
""" + _max_seq_len_cached = 0 + def __init__(self, config: LlamaConfig, device=None): super().__init__(config=config) @@ -217,6 +219,10 @@ def forward( kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 + if kv_seq_len > QEffLlamaRotaryEmbedding._max_seq_len_cached: + QEffLlamaRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -318,6 +324,7 @@ class QEffLlamaModel(LlamaModel): def __qeff_init__(self): self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + QEffLlamaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 3a04e98004..21e8306eb9 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -113,6 +113,10 @@ def forward( kv_seq_len = past_key_value.get_seq_length(self.layer_idx) key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) + if kv_seq_len > QEffLlamaRotaryEmbedding._max_seq_len_cached: + QEffLlamaRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -214,6 +218,7 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): def __qeff_init__(self): self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + QEffLlamaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) @@ -388,6 +393,10 @@ def forward( ) kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) + if kv_seq_len > QEffLlamaRotaryEmbedding._max_seq_len_cached: + QEffLlamaRotaryEmbedding.rotary_emb._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos, sin = self.cos_cached[:kv_seq_len], self.sin_cached[:kv_seq_len] _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index f254aa0f0a..de462457aa 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -40,6 +40,8 @@ class QEffMistralRotaryEmbedding(MistralRotaryEmbedding): - Add static sin/cos computations. 
""" + _max_seq_len_cached = 0 + def __init__(self, config: MistralConfig, device=None): super().__init__(config=config) @@ -149,6 +151,10 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffMistralRotaryEmbedding._max_seq_len_cached: + QEffMistralRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -253,6 +259,7 @@ class QEffMistralModel(MistralModel): def __qeff_init__(self): self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config) + QEffMistralRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index a27f6d2462..b4c65b41da 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -42,6 +42,8 @@ class QEffMixtralRotaryEmbedding(MixtralRotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: MixtralConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -146,6 +148,10 @@ def forward( ) kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + if kv_seq_len > QEffMixtralRotaryEmbedding._max_seq_len_cached: + QEffMixtralRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -321,6 +327,7 @@ class QEffMixtralModel(MixtralModel): def __qeff_init__(self): self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) + QEffMixtralRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 94de5faa0d..204613bc29 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -103,6 +103,8 @@ class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): - Add static sin/cos computations. 
""" + _max_seq_len_cached = 0 + def __init__(self, config: MllamaConfig, device=None): super().__init__(config=config) @@ -265,6 +267,10 @@ def forward( ) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffMllamaRotaryEmbedding._max_seq_len_cached: + QEffMllamaRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -601,6 +607,7 @@ class QEffMllamaTextModel(MllamaTextModel): def __qeff_init__(self): self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) + QEffMllamaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index d89ce47383..4f84aa4960 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -37,6 +37,8 @@ class QEffOlmo2RotaryEmbedding(Olmo2RotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: Olmo2Config, device=None): super().__init__(config=config) @@ -140,6 +142,10 @@ def forward( kv_seq_len = key_states.shape[-2] kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffOlmo2RotaryEmbedding._max_seq_len_cached: + QEffOlmo2RotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -230,6 +236,7 @@ class QEffOlmo2Model(Olmo2Model): def __qeff_init__(self): self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) + QEffOlmo2RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 4fa05fa8bc..8932452177 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -37,6 +37,8 @@ class QEffPhi3RotaryEmbedding(Phi3RotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: Phi3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. 
@@ -147,6 +149,10 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffPhi3RotaryEmbedding._max_seq_len_cached: + QEffPhi3RotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -262,6 +268,7 @@ class QEffPhi3Model(Phi3Model): def __qeff_init__(self): self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) + QEffPhi3RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 85575196c1..e511d60f6d 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -41,6 +41,8 @@ class QEffQwen2RotaryEmbedding(Qwen2RotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: Qwen2Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -152,6 +154,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffQwen2RotaryEmbedding._max_seq_len_cached: + QEffQwen2RotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -258,6 +264,7 @@ class QEffQwen2Model(Qwen2Model): def __qeff_init__(self): self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) + QEffQwen2RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 9f7cbde56e..3523488c33 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -331,6 +331,8 @@ class QEffQwen2_5_VLRotaryEmbedding(Qwen2_5_VLRotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: Qwen2_5_VLConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. 
@@ -584,6 +586,10 @@ def forward( kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 + if kv_seq_len > QEffQwen2_5_VLRotaryEmbedding._max_seq_len_cached: + QEffQwen2_5_VLRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -720,6 +726,7 @@ def forward( class QEffQwen2_5_VLTextModel(Qwen2_5_VLTextModel): def __qeff_init__(self): self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config) + QEffQwen2_5_VLRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index bcaa75dd4c..0a097edef8 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -41,6 +41,8 @@ class QEffQwen3RotaryEmbedding(Qwen3RotaryEmbedding): - Add static sin/cos computations. """ + _max_seq_len_cached = 0 + def __init__(self, config: Qwen3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -153,6 +155,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffQwen3RotaryEmbedding._max_seq_len_cached: + QEffQwen3RotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -259,6 +265,7 @@ class QEffQwen3Model(Qwen3Model): def __qeff_init__(self): self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) + QEffQwen3RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 3a1f5dfa3d..6e5b372b1d 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -32,6 +32,8 @@ class QEffQwen3MoeRotaryEmbedding(Qwen3MoeRotaryEmbedding): + _max_seq_len_cached = 0 + def __init__(self, config: Qwen3MoeConfig, device=None): super().__init__(config=config) @@ -199,6 +201,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + if kv_seq_len > QEffQwen3MoeRotaryEmbedding._max_seq_len_cached: + QEffQwen3MoeRotaryEmbedding._set_cos_sin_cache( + seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype + ) cos_cached[:kv_seq_len].to(dtype=value_states.dtype) sin_cached[:kv_seq_len].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -293,6 +299,7 @@ def forward( class QEffQwen3MoeModel(Qwen3MoeModel): def __qeff_init__(self): self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config) + 
QEffQwen3MoeRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.rotary_emb._set_cos_sin_cache( seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype ) diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index f17edab654..0d2fc3bc64 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -16,26 +16,26 @@ torch.manual_seed(42) configs = [ - ("gpt2", 256, 2, 4, 128, 512, 127, {}), - # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("gpt2", 256, 2, 4, 128, 512, 127, {}), + # # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), ("falcon", 256, 2, 4, 128, 512, 127, {}), - ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mpt", 256, 2, 4, 128, 512, 127, {}), - # ("phi", 256, 2, 4, 128, 512, 127, {}), - ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), - ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), - ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mpt", 256, 2, 4, 128, 512, 127, {}), + # # ("phi", 256, 2, 4, 128, 512, 127, {}), + # ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + # ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + # ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ From f6adeaa1861e0a89dd67dcc72ed49bacea479452 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 25 Mar 2026 09:35:48 +0000 Subject: [PATCH 5/9] simplifying Signed-off-by: Onkar Chougule --- .../models/gpt_oss/modeling_gpt_oss.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 5b150ae55a..85ae4e807d 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ 
b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -510,8 +510,6 @@ class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: GptOssConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -750,17 +748,19 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - max_seq_len_cached = 32 * 1024 + # if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + # max_seq_len_cached = 32 * 1024 - if max_seq_len_cached > QEffGptOssRotaryEmbedding._max_seq_len_cached: - QEffGptOssRotaryEmbedding._set_cos_sin_cache( - seq_len=max_seq_len_cached, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:max_seq_len_cached].to(dtype=value_states.dtype) - sin_cached[:max_seq_len_cached].to(dtype=value_states.dtype) + # if max_seq_len_cached > QEffGptOssRotaryEmbedding._max_seq_len_cached: + # QEffGptOssRotaryEmbedding._set_cos_sin_cache( + # seq_len=max_seq_len_cached, device=value_states.device, dtype=value_states.dtype + # ) + # cos_cached[:max_seq_len_cached].to(dtype=value_states.dtype) + # sin_cached[:max_seq_len_cached].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -925,8 +925,6 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - # max_seq_len_cached = 32 * 1024 cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) cos, sin = cos_cached, sin_cached @@ -1030,9 +1028,6 @@ def forward( class QEffPrefillOnlyGptOssModel(GptOssModel): def __qeff_init__(self): self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) @@ -1131,10 +1126,6 @@ def forward( class QEffGptOssModel(GptOssModel): def __qeff_init__(self): self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - QEffGptOssRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) From fc6f938cddf1572bc2bc5950d6e668d98dfb1bb6 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 25 Mar 2026 15:33:56 +0000 Subject: [PATCH 6/9] 
simplified modeling file Signed-off-by: abhishek-singh591 --- .../models/falcon/modeling_falcon.py | 16 +-- .../models/gemma/modeling_gemma.py | 19 +-- .../models/gemma2/modeling_gemma2.py | 23 +-- .../models/gpt_oss/modeling_gpt_oss.py | 38 ++--- .../models/granite/modeling_granite.py | 23 +-- .../models/granitemoe/modeling_granitemoe.py | 23 +-- .../models/llama/modeling_llama.py | 19 +-- .../llama_swiftkv/modeling_llama_swiftkv.py | 26 +--- .../models/mistral/modeling_mistral.py | 19 +-- .../models/mixtral_moe/modeling_mixtral.py | 19 +-- .../models/mllama/modeling_mllama.py | 19 +-- .../models/olmo2/modeling_olmo2.py | 22 +-- .../transformers/models/phi3/modeling_phi3.py | 19 +-- .../models/qwen2/modeling_qwen2.py | 19 +-- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 28 +--- .../models/qwen3/modeling_qwen3.py | 19 +-- .../models/qwen3_moe/modeling_qwen3_moe.py | 19 +-- examples/basic_vlm_inference.py | 135 ++++++++++++++++++ test.py | 13 ++ 19 files changed, 234 insertions(+), 284 deletions(-) create mode 100644 examples/basic_vlm_inference.py create mode 100644 test.py diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 4278e7f87d..e70f32818f 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -42,8 +42,6 @@ class QEffFalconRotaryEmbedding(FalconRotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: FalconConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -128,14 +126,8 @@ def forward( key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - - if kv_seq_len > QEffFalconRotaryEmbedding._max_seq_len_cached: - self._set_cos_sin_cache(seq_len=kv_seq_len, device=value_layer.device, dtype=value_layer.dtype) - cos_cached[:kv_seq_len].to(dtype=value_layer.dtype) - sin_cached[:kv_seq_len].to(dtype=value_layer.dtype) - cos, sin = cos_cached, sin_cached - query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached, position_ids) if layer_past is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -247,10 +239,6 @@ class QEffFalconModel(FalconModel): def __qeff_init__(self): self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config) - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) - QEffFalconRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index badc6ee3a0..3bed2d00ee 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -37,8 +37,6 @@ class 
QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: GemmaConfig, device=None): super().__init__(config=config) @@ -140,15 +138,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffGemmaRotaryEmbedding._max_seq_len_cached: - QEffGemmaRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -246,10 +239,6 @@ class QEffGemmaModel(GemmaModel): def __qeff_init__(self): self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) - QEffGemmaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 2a779c4d40..8e2e823c7f 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -40,8 +40,6 @@ class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding): - Add static sin/cos computations. 
""" - _max_seq_len_cached = 0 - def __init__(self, config: Gemma2Config, device=None): super().__init__(config=config) @@ -147,21 +145,16 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffGemma2RotaryEmbedding._max_seq_len_cached: - QEffGemma2RotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, } @@ -274,10 +267,6 @@ class QEffGemma2Model(Gemma2Model): def __qeff_init__(self): self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) - QEffGemma2RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 85ae4e807d..6f4e1d8c43 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -748,16 +748,6 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - # max_seq_len_cached = 32 * 1024 - - # if max_seq_len_cached > QEffGptOssRotaryEmbedding._max_seq_len_cached: - # QEffGptOssRotaryEmbedding._set_cos_sin_cache( - # seq_len=max_seq_len_cached, device=value_states.device, dtype=value_states.dtype - # ) - # cos_cached[:max_seq_len_cached].to(dtype=value_states.dtype) - # sin_cached[:max_seq_len_cached].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached query_states, key_states = qeff_apply_rotary_pos_emb( query_states, key_states, cos_cached, sin_cached, position_ids ) @@ -765,8 +755,8 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, "config": self.config, @@ -842,18 +832,15 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # if not (max_seq_len_cached := getattr(self.config, 
"max_seq_len_cached")): - # max_seq_len_cached = 32 * 1024 - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, "config": self.config, @@ -925,16 +912,15 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - sin_cached[: hidden_states.shape[1]].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, "config": self.config, diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index a565af999f..c2af97f55d 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -37,8 +37,6 @@ class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding): - Add static sin/cos computations. 
""" - _max_seq_len_cached = 0 - def __init__(self, config: GraniteConfig, device=None): super().__init__(config=config) self._set_cos_sin_cache( @@ -133,21 +131,16 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffGraniteRotaryEmbedding._max_seq_len_cached: - QEffGraniteRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, } @@ -252,10 +245,6 @@ def forward( class QEffGraniteModel(GraniteModel): def __qeff_init__(self): self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) - QEffGraniteRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 2bccd1250a..82bb8533a6 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -38,8 +38,6 @@ class QEffGraniteMoeRotaryEmbedding(GraniteMoeRotaryEmbedding): - Add static sin/cos computations. 
""" - _max_seq_len_cached = 0 - def __init__( self, config: Optional[GraniteMoeConfig] = None, @@ -128,20 +126,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffGraniteMoeRotaryEmbedding._max_seq_len_cached: - QEffGraniteMoeRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "cache_position": cache_position, "batch_index": batch_index, "position_ids": position_ids, @@ -290,10 +283,6 @@ class QEffGraniteMoeModel(GraniteMoeModel): def __qeff_init__(self): self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) - QEffGraniteMoeRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index b44473af84..5b501d36fa 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -37,8 +37,6 @@ class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): - Add static sin/cos computations. 
""" - _max_seq_len_cached = 0 - def __init__(self, config: LlamaConfig, device=None): super().__init__(config=config) @@ -217,16 +215,11 @@ def forward( key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - if kv_seq_len > QEffLlamaRotaryEmbedding._max_seq_len_cached: - QEffLlamaRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: if num_kv_blocks is not None: @@ -324,10 +317,6 @@ class QEffLlamaModel(LlamaModel): def __qeff_init__(self): self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) - QEffLlamaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 21e8306eb9..537d358a38 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -110,19 +110,12 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) - if kv_seq_len > QEffLlamaRotaryEmbedding._max_seq_len_cached: - QEffLlamaRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( - query_states, torch.empty_like(query_states), cos, sin, position_ids + query_states, torch.empty_like(query_states), cos_cached, sin_cached, position_ids ) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -218,10 +211,6 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): def __qeff_init__(self): self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) - QEffLlamaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, 
device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) @@ -391,14 +380,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) + # kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) - if kv_seq_len > QEffLlamaRotaryEmbedding._max_seq_len_cached: - QEffLlamaRotaryEmbedding.rotary_emb._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos, sin = self.cos_cached[:kv_seq_len], self.sin_cached[:kv_seq_len] - _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) + _, key_states = qeff_apply_rotary_pos_emb( + torch.empty_like(key_states), key_states, self.cos_cached, self.sin_cached, position_ids + ) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index de462457aa..76a7d24c64 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -40,8 +40,6 @@ class QEffMistralRotaryEmbedding(MistralRotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: MistralConfig, device=None): super().__init__(config=config) @@ -150,15 +148,10 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffMistralRotaryEmbedding._max_seq_len_cached: - QEffMistralRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -259,10 +252,6 @@ class QEffMistralModel(MistralModel): def __qeff_init__(self): self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config) - QEffMistralRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index b4c65b41da..e59a3be534 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ 
-42,8 +42,6 @@ class QEffMixtralRotaryEmbedding(MixtralRotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: MixtralConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -146,16 +144,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - if kv_seq_len > QEffMixtralRotaryEmbedding._max_seq_len_cached: - QEffMixtralRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -327,10 +320,6 @@ class QEffMixtralModel(MixtralModel): def __qeff_init__(self): self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) - QEffMixtralRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 204613bc29..642fb4bb7c 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -103,8 +103,6 @@ class QEffMllamaRotaryEmbedding(MllamaRotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: MllamaConfig, device=None): super().__init__(config=config) @@ -265,16 +263,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffMllamaRotaryEmbedding._max_seq_len_cached: - QEffMllamaRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = { @@ -607,10 +600,6 @@ class QEffMllamaTextModel(MllamaTextModel): def __qeff_init__(self): self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) - QEffMllamaRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 4f84aa4960..22834d2926 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -37,8 +37,6 @@ class QEffOlmo2RotaryEmbedding(Olmo2RotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: Olmo2Config, device=None): super().__init__(config=config) @@ -139,17 +137,11 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffOlmo2RotaryEmbedding._max_seq_len_cached: - QEffOlmo2RotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = key_states.shape[-2] + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -236,10 +228,6 @@ class QEffOlmo2Model(Olmo2Model): def __qeff_init__(self): self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) - QEffOlmo2RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py 
b/QEfficient/transformers/models/phi3/modeling_phi3.py index 8932452177..b18dbcd5c0 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -37,8 +37,6 @@ class QEffPhi3RotaryEmbedding(Phi3RotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: Phi3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. @@ -148,16 +146,11 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffPhi3RotaryEmbedding._max_seq_len_cached: - QEffPhi3RotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = { @@ -268,10 +261,6 @@ class QEffPhi3Model(Phi3Model): def __qeff_init__(self): self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) - QEffPhi3RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index e511d60f6d..df7421c466 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -41,8 +41,6 @@ class QEffQwen2RotaryEmbedding(Qwen2RotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: Qwen2Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. 
@@ -153,15 +151,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffQwen2RotaryEmbedding._max_seq_len_cached: - QEffQwen2RotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -264,10 +257,6 @@ class QEffQwen2Model(Qwen2Model): def __qeff_init__(self): self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) - QEffQwen2RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 3523488c33..f333302bc0 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -331,8 +331,6 @@ class QEffQwen2_5_VLRotaryEmbedding(Qwen2_5_VLRotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: Qwen2_5_VLConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. 
@@ -582,27 +580,19 @@ def forward( key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = key_states.shape[-2] + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - if kv_seq_len > QEffQwen2_5_VLRotaryEmbedding._max_seq_len_cached: - QEffQwen2_5_VLRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids[1:], self.rope_scaling["mrope_section"] + query_states, key_states, cos_cached, sin_cached, position_ids[1:], self.rope_scaling["mrope_section"] ) if past_key_value is not None: if num_kv_blocks is not None: cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids[0], "past_seen_tokens": past_seen_tokens, @@ -611,8 +601,8 @@ def forward( else: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids[0], } @@ -726,10 +716,6 @@ def forward( class QEffQwen2_5_VLTextModel(Qwen2_5_VLTextModel): def __qeff_init__(self): self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config) - QEffQwen2_5_VLRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index 0a097edef8..4202f52e18 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -41,8 +41,6 @@ class QEffQwen3RotaryEmbedding(Qwen3RotaryEmbedding): - Add static sin/cos computations. """ - _max_seq_len_cached = 0 - def __init__(self, config: Qwen3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. 
@@ -154,15 +152,10 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffQwen3RotaryEmbedding._max_seq_len_cached: - QEffQwen3RotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -265,10 +258,6 @@ class QEffQwen3Model(Qwen3Model): def __qeff_init__(self): self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config) - QEffQwen3RotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 6e5b372b1d..fb7320ff6a 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -32,8 +32,6 @@ class QEffQwen3MoeRotaryEmbedding(Qwen3MoeRotaryEmbedding): - _max_seq_len_cached = 0 - def __init__(self, config: Qwen3MoeConfig, device=None): super().__init__(config=config) @@ -200,15 +198,10 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - if kv_seq_len > QEffQwen3MoeRotaryEmbedding._max_seq_len_cached: - QEffQwen3MoeRotaryEmbedding._set_cos_sin_cache( - seq_len=kv_seq_len, device=value_states.device, dtype=value_states.dtype - ) - cos_cached[:kv_seq_len].to(dtype=value_states.dtype) - sin_cached[:kv_seq_len].to(dtype=value_states.dtype) - cos, sin = cos_cached, sin_cached - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -299,10 +292,6 @@ def forward( class QEffQwen3MoeModel(Qwen3MoeModel): def __qeff_init__(self): self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config) - QEffQwen3MoeRotaryEmbedding._max_seq_len_cached = self.config.max_position_embeddings - self.rotary_emb._set_cos_sin_cache( - seq_len=self.config.max_position_embeddings, device=self.device, dtype=self.dtype - ) self.sin_cached = 
torch.nn.Parameter(self.rotary_emb.sin_cached) self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) diff --git a/examples/basic_vlm_inference.py b/examples/basic_vlm_inference.py new file mode 100644 index 0000000000..c6008afbb6 --- /dev/null +++ b/examples/basic_vlm_inference.py @@ -0,0 +1,135 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import argparse + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + query, + image_url, + kv_offload=True, + prefill_seq_len=32, + ctx_len=512, + generation_len=128, + img_size=336, + num_cores=16, + num_devices=1, +): + ## STEP 1: Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name) + + # `kv_offload` determines Single QPC vs Dual QPC mode: + # - Single QPC (kv_offload=False): Entire model runs in one QPC + # - Dual QPC (kv_offload=True): Vision encoder and language model run in separate QPCs + # with outputs transferred via host for flexibility + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, attn_implementation="eager", kv_offload=kv_offload + ) + + ## STEP 2: Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + use_onnx_subfunctions=True, + ) + + ## STEP 3: Load and Process the Inputs for Inference + # Note: the message format would change for different model + image = Image.open(requests.get(image_url, stream=True).raw) + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=prefill_seq_len, + ) + + ## STEP 4: Run Inference on the Compiled Model + + streamer = TextStreamer(processor.tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + + +def main(): + parser = argparse.ArgumentParser(description="Vision-Language Model (VLM) inference") + parser.add_argument( + "--model-name", + type=str, + default="Qwen/Qwen2.5-VL-3B-Instruct", + help="HuggingFace VLM model ID", + ) + parser.add_argument( + "--query", + type=str, + default="Describe this image.", + help="Text query/question about the image", + ) + parser.add_argument( + "--image-url", + type=str, + default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + help="URL of the image to process", + ) + parser.add_argument( + "--kv-offload", + action="store_true", + default=True, + help="Enable Dual QPC mode (vision encoder and LM in separate QPCs)", + ) + parser.add_argument("--prefill-seq-len", type=int, default=128, help="Prefill sequence length") + parser.add_argument("--ctx-len", type=int, default=3000, help="Context length") + parser.add_argument("--generation-len", type=int, default=128, help="Number of tokens to generate") + parser.add_argument("--img-size", type=int, default=336, help="Image size 
for processing") + parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") + parser.add_argument("--num-devices", type=int, default=1, help="Number of devices") + args = parser.parse_args() + + print(f"Running VLM inference with model: {args.model_name}") + print(f"KV offload (Dual QPC mode): {args.kv_offload}") + + run_model( + model_name=args.model_name, + query=args.query, + image_url=args.image_url, + kv_offload=args.kv_offload, + prefill_seq_len=args.prefill_seq_len, + ctx_len=args.ctx_len, + generation_len=args.generation_len, + img_size=args.img_size, + num_cores=args.num_cores, + num_devices=args.num_devices, + ) + + +if __name__ == "__main__": + main() diff --git a/test.py b/test.py new file mode 100644 index 0000000000..fc05dfa34a --- /dev/null +++ b/test.py @@ -0,0 +1,13 @@ +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +model_name = "meta-llama/Llama-3.2-1B" +model1 = QEFFAutoModelForCausalLM.from_pretrained(model_name) +# model2=QEFFAutoModelForCausalLM.from_pretrained(model_name, num_hidden_layers = 2) +# with_sub_func_onnx = model1.export(use_onnx_subfunctions=True) +model1.compile(num_devices=1, num_cores=16, use_onnx_subfunctions=True) +hash_0_1 = model1.export_hash +inputs = "Help me with this" +tokenizer = AutoTokenizer.from_pretrained(model_name) +generation_00 = model1.generate(prompts=["Help me with this"], tokenizer=tokenizer) \ No newline at end of file From a45276d882506e2ab63522957bee5b6437dcffc6 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Wed, 25 Mar 2026 15:34:51 +0000 Subject: [PATCH 7/9] simplified modeling file Signed-off-by: abhishek-singh591 --- examples/basic_vlm_inference.py | 135 -------------------------------- test.py | 13 --- 2 files changed, 148 deletions(-) delete mode 100644 examples/basic_vlm_inference.py delete mode 100644 test.py diff --git a/examples/basic_vlm_inference.py b/examples/basic_vlm_inference.py deleted file mode 100644 index c6008afbb6..0000000000 --- a/examples/basic_vlm_inference.py +++ /dev/null @@ -1,135 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import argparse - -import requests -from PIL import Image -from transformers import AutoProcessor, TextStreamer - -from QEfficient import QEFFAutoModelForImageTextToText - - -def run_model( - model_name, - query, - image_url, - kv_offload=True, - prefill_seq_len=32, - ctx_len=512, - generation_len=128, - img_size=336, - num_cores=16, - num_devices=1, -): - ## STEP 1: Load the Processor and Model - - processor = AutoProcessor.from_pretrained(model_name) - - # `kv_offload` determines Single QPC vs Dual QPC mode: - # - Single QPC (kv_offload=False): Entire model runs in one QPC - # - Dual QPC (kv_offload=True): Vision encoder and language model run in separate QPCs - # with outputs transferred via host for flexibility - - model = QEFFAutoModelForImageTextToText.from_pretrained( - model_name, attn_implementation="eager", kv_offload=kv_offload - ) - - ## STEP 2: Export & Compile the Model - - model.compile( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - img_size=img_size, - num_cores=num_cores, - num_devices=num_devices, - mxfp6_matmul=False, - use_onnx_subfunctions=True, - ) - - ## STEP 3: Load and Process the Inputs for Inference - # Note: the message format would change for different model - image = Image.open(requests.get(image_url, stream=True).raw) - messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": query}, - ], - } - ] - input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] - - inputs = processor( - text=input_text, - images=image, - return_tensors="pt", - add_special_tokens=False, - padding="max_length", - max_length=prefill_seq_len, - ) - - ## STEP 4: Run Inference on the Compiled Model - - streamer = TextStreamer(processor.tokenizer) - model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) - - -def main(): - parser = argparse.ArgumentParser(description="Vision-Language Model (VLM) inference") - parser.add_argument( - "--model-name", - type=str, - default="Qwen/Qwen2.5-VL-3B-Instruct", - help="HuggingFace VLM model ID", - ) - parser.add_argument( - "--query", - type=str, - default="Describe this image.", - help="Text query/question about the image", - ) - parser.add_argument( - "--image-url", - type=str, - default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", - help="URL of the image to process", - ) - parser.add_argument( - "--kv-offload", - action="store_true", - default=True, - help="Enable Dual QPC mode (vision encoder and LM in separate QPCs)", - ) - parser.add_argument("--prefill-seq-len", type=int, default=128, help="Prefill sequence length") - parser.add_argument("--ctx-len", type=int, default=3000, help="Context length") - parser.add_argument("--generation-len", type=int, default=128, help="Number of tokens to generate") - parser.add_argument("--img-size", type=int, default=336, help="Image size for processing") - parser.add_argument("--num-cores", type=int, default=16, help="Number of cores") - parser.add_argument("--num-devices", type=int, default=1, help="Number of devices") - args = parser.parse_args() - - print(f"Running VLM inference with model: {args.model_name}") - print(f"KV offload (Dual QPC mode): {args.kv_offload}") - - run_model( - model_name=args.model_name, - query=args.query, - image_url=args.image_url, - kv_offload=args.kv_offload, - 
prefill_seq_len=args.prefill_seq_len, - ctx_len=args.ctx_len, - generation_len=args.generation_len, - img_size=args.img_size, - num_cores=args.num_cores, - num_devices=args.num_devices, - ) - - -if __name__ == "__main__": - main() diff --git a/test.py b/test.py deleted file mode 100644 index fc05dfa34a..0000000000 --- a/test.py +++ /dev/null @@ -1,13 +0,0 @@ -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM - -model_name = "meta-llama/Llama-3.2-1B" -model1 = QEFFAutoModelForCausalLM.from_pretrained(model_name) -# model2=QEFFAutoModelForCausalLM.from_pretrained(model_name, num_hidden_layers = 2) -# with_sub_func_onnx = model1.export(use_onnx_subfunctions=True) -model1.compile(num_devices=1, num_cores=16, use_onnx_subfunctions=True) -hash_0_1 = model1.export_hash -inputs = "Help me with this" -tokenizer = AutoTokenizer.from_pretrained(model_name) -generation_00 = model1.generate(prompts=["Help me with this"], tokenizer=tokenizer) \ No newline at end of file From dc2d645959b0b783fda52a1ad455a05988de95b0 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Thu, 26 Mar 2026 05:51:23 +0000 Subject: [PATCH 8/9] Modified llama swiftkv Signed-off-by: abhishek-singh591 --- .../llama_swiftkv/modeling_llama_swiftkv.py | 30 +++++++++++---- .../causallm/example_pytorch_transforms.py | 12 +++--- .../models/test_single_subfunction.py | 38 +++++++++---------- 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 537d358a38..2e8a526d79 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -144,7 +144,7 @@ def forward( class QEffLlamaSwiftKVDecoderLayer(nn.Module): - def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: + def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx, sin_cached, cos_cached) -> None: super().__init__() self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads @@ -152,6 +152,8 @@ def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.sin_cached = sin_cached + self.cos_cached = cos_cached def forward( self, @@ -167,6 +169,10 @@ def forward( # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + if sin_cached is None: + sin_cached = self.sin_cached + if cos_cached is None: + cos_cached = self.cos_cached hidden_states, past_key_values = self.self_attn( hidden_states=hidden_states, @@ -197,23 +203,24 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): self.vocab_size = config.vocab_size self.config = config + self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, None) + sin_cached, cos_cached = self.sin_cached, self.cos_cached self.layers = torch.nn.ModuleList( [ QEffLlamaDecoderLayer(config=config, layer_idx=idx) if idx < config.num_key_value_layers - else 
QEffLlamaSwiftKVDecoderLayer(config=config, layer_idx=idx) + else QEffLlamaSwiftKVDecoderLayer( + config=config, layer_idx=idx, sin_cached=sin_cached, cos_cached=cos_cached + ) for idx in range(config.num_hidden_layers) ] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def __qeff_init__(self): - self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) - self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) - self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) - def _run_swiftkv_layers( self, hidden_states: torch.Tensor, @@ -226,7 +233,14 @@ def _run_swiftkv_layers( for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] hidden_states = layer( - hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index + hidden_states, + position_ids, + past_key_values, + comp_ctx_lengths, + causal_mask, + batch_index, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index ff62588f9c..503efc12dc 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,12 +27,6 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from torch import nn # Example imports for three representative models @@ -62,6 +56,12 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer, diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py index 0d2fc3bc64..f17edab654 100644 --- a/tests/transformers/models/test_single_subfunction.py +++ b/tests/transformers/models/test_single_subfunction.py @@ -16,26 +16,26 @@ torch.manual_seed(42) configs = [ - # ("gpt2", 256, 2, 4, 128, 512, 127, {}), - # # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("gpt2", 256, 2, 4, 128, 512, 127, {}), + # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), ("falcon", 256, 2, 4, 128, 512, 127, {}), - # ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - # ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("mpt", 256, 2, 4, 128, 512, 127, {}), - # # ("phi", 256, 2, 4, 128, 512, 127, {}), - # ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), - # ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # 
("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), - # ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - # ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), + # ("phi", 256, 2, 4, 128, 512, 127, {}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ From fce8237799786b76cf723d507fa73ffe0a769322 Mon Sep 17 00:00:00 2001 From: abhishek-singh591 Date: Thu, 26 Mar 2026 05:54:05 +0000 Subject: [PATCH 9/9] Modified llama swiftkv Signed-off-by: abhishek-singh591 --- .../causallm/example_pytorch_transforms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py index 503efc12dc..ff62588f9c 100644 --- a/examples/onboarding_guide/causallm/example_pytorch_transforms.py +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -27,6 +27,12 @@ from types import MethodType from typing import Callable, Optional, Tuple, Union +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) from torch import nn # Example imports for three representative models @@ -56,12 +62,6 @@ from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform from QEfficient.customop import CustomRMSNormAIC from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function -from QEfficient.transformers.models.blueprint.modeling_blueprint import ( - QEffBlueprintAttention, - QEffBlueprintDecoderLayer, - QEffBlueprintForCausalLM, - QEffBlueprintModel, -) from QEfficient.transformers.models.llama.modeling_llama import ( QEffLlamaAttention, QEffLlamaDecoderLayer,