diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 4ebb2fb96e..e70f32818f 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -59,16 +59,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -108,9 +98,6 @@ class QEffFalconAttention(FalconAttention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -125,6 +112,8 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ): fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads @@ -137,9 +126,8 @@ def forward( key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len) - query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached, position_ids) if layer_past is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -184,6 +172,8 @@ def forward( use_cache: bool = False, output_attentions: bool = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ): residual = hidden_states @@ -208,6 +198,8 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, ) if not self.config.new_decoder_architecture: @@ -245,6 +237,11 @@ class QEffFalconModel(FalconModel): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffFalconRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: torch.LongTensor = None, @@ -322,6 +319,8 @@ def forward( output_attentions=output_attentions, alibi=alibi, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) 
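The pattern applied to Falcon above (and repeated for every model below) is the same: the per-attention-layer rotary-embedding module and its forward call are removed, the model-level __qeff_init__ builds the cos/sin tables once, and those tables are threaded through the decoder layers into attention as cos_cached/sin_cached, where qeff_apply_rotary_pos_emb indexes them by position_ids. The following is a minimal, self-contained sketch of that flow, not code from this patch; rotate_half and apply_rotary here are illustrative stand-ins for the repository's helpers.

import torch

def rotate_half(x):
    # Split the head dimension in half and swap the halves with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos_cached, sin_cached, position_ids, unsqueeze_dim=1):
    # Gather the precomputed rows for the requested positions, then rotate q and k.
    cos = cos_cached[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin_cached[position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# Built once at the model level (mirrors the role of the new __qeff_init__ hooks):
head_dim, max_pos = 64, 128
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()

# Passed down to every attention layer as plain tensors:
q = torch.randn(1, 8, 4, head_dim)            # [bs, heads, seq, head_dim]
k = torch.randn(1, 8, 4, head_dim)
position_ids = torch.arange(4).unsqueeze(0)   # [bs, seq]
q_rot, k_rot = apply_rotary(q, k, cos_cached, sin_cached, position_ids)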
hidden_states = outputs[0] diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 260d1857a7..3bed2d00ee 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -55,16 +55,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -128,9 +118,6 @@ class QEffGemmaAttention(GemmaAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -140,6 +127,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -149,9 +138,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -194,6 +184,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -223,6 +215,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -243,6 +237,11 @@ class QEffGemmaModel(GemmaModel): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffGemmaRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -310,6 +309,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py 
b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 6dee8c85dd..8e2e823c7f 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -135,9 +125,6 @@ class QEffGemma2Attention(Gemma2Attention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -147,6 +134,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -156,15 +145,16 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, } @@ -208,6 +198,8 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -241,6 +233,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -271,6 +265,11 @@ class QEffGemma2Model(Gemma2Model): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffGemma2RotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -355,6 +354,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + 
sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index e8f5fa89b3..6f4e1d8c43 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -527,16 +527,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -737,9 +727,6 @@ def opt_eager_attention_forward_blocked( class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -751,6 +738,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -759,16 +748,15 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, "config": self.config, @@ -823,9 +811,6 @@ def forward( class QEffPrefillOnlyGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -837,6 +822,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -845,16 +832,15 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if 
not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, "config": self.config, @@ -905,9 +891,6 @@ def forward( class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -919,6 +902,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, sliding_mask=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] @@ -927,16 +912,15 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): - max_seq_len_cached = 32 * 1024 - cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, "config": self.config, @@ -986,6 +970,8 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC sliding_mask=None, + sin_cached=None, + cos_cached=None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor]: residual = hidden_states @@ -1002,6 +988,8 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, sliding_mask=sliding_mask, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -1024,6 +1012,11 @@ def forward( class QEffPrefillOnlyGptOssModel(GptOssModel): + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1093,6 +1086,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) hidden_states = layer_outputs[0] 
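For the models whose removed per-layer rotary forward multiplied the tables by attention_scaling, the new model-level hook folds that scaling into the cached tensors once, so attention no longer applies it per call. Below is a hedged sketch of such a model-level owner under that assumption; RotaryCacheOwner is a hypothetical name for illustration only, not a class introduced by this patch.

import torch

class RotaryCacheOwner(torch.nn.Module):
    # Hypothetical stand-in for the QEff*Model __qeff_init__ hooks added in this diff.
    def __init__(self, cos_cached, sin_cached, attention_scaling=1.0):
        super().__init__()
        # Scaling is baked in once; decoder layers and attention just receive the tensors.
        self.cos_cached = torch.nn.Parameter(cos_cached * attention_scaling)
        self.sin_cached = torch.nn.Parameter(sin_cached * attention_scaling)

    def forward(self, position_ids):
        # The same two tensors are forwarded to every decoder layer as keyword arguments.
        return self.cos_cached[position_ids], self.sin_cached[position_ids]

emb = torch.randn(128, 64)
owner = RotaryCacheOwner(emb.cos(), emb.sin(), attention_scaling=1.0)
cos, sin = owner(torch.arange(8).unsqueeze(0))   # each of shape [1, 8, 64]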
@@ -1115,6 +1110,11 @@ def forward( class QEffGptOssModel(GptOssModel): + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1187,6 +1187,8 @@ def forward( output_attentions=output_attentions, cache_position=cache_position, sliding_mask=sliding_mask, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 8a32c52ef2..c2af97f55d 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -53,16 +53,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -121,9 +111,6 @@ def eager_attention_forward( class QEffGraniteAttention(GraniteAttention): - def __qeff_init__(self): - self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -133,6 +120,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -142,15 +131,16 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "batch_index": batch_index, "position_ids": position_ids, } @@ -192,6 +182,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -230,6 +222,8 @@ def forward( 
use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -249,6 +243,11 @@ def forward( class QEffGraniteModel(GraniteModel): + def __qeff_init__(self): + self.rotary_emb = QEffGraniteRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: torch.LongTensor = None, @@ -316,6 +315,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 935df7c2d9..82bb8533a6 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -59,16 +59,6 @@ def _set_cos_sin_cache(self, seq_len: int, device=None, dtype=None): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x: torch.Tensor, seq_len: int = None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb( q: torch.Tensor, @@ -111,9 +101,6 @@ def qeff_apply_rotary_pos_emb( class QEffGraniteMoeAttention(GraniteMoeAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -126,6 +113,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -137,14 +126,15 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { - "sin": sin, - "cos": cos, + "sin": sin_cached, + "cos": cos_cached, "cache_position": cache_position, "batch_index": batch_index, "position_ids": position_ids, @@ -214,6 +204,8 @@ def forward( cache_position: 
Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -255,6 +247,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) @@ -287,6 +281,11 @@ class QEffGraniteMoeModel(GraniteMoeModel): """ + def __qeff_init__(self): + self.rotary_emb = QEffGraniteMoeRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: torch.LongTensor = None, @@ -356,6 +355,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) else: layer_outputs = decoder_layer( @@ -368,6 +369,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 57bccdb1bb..5b501d36fa 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -54,16 +54,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -198,9 +188,6 @@ def eager_attention_forward_blockedKV( class QEffLlamaAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -212,6 +199,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, num_kv_blocks: Optional[torch.Tensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -226,10 +215,11 @@ def forward( key_states = self.k_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: if num_kv_blocks is not None: @@ -287,6 +277,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -303,6 +295,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -321,6 +315,11 @@ class QEffLlamaModel(LlamaModel): Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py """ + def __qeff_init__(self): + self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: torch.LongTensor = None, @@ -380,6 +379,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index e219d5e03a..2e8a526d79 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -82,8 +82,6 @@ def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: ) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = QEffLlamaRotaryEmbedding(config=config) - def forward( self, hidden_states: torch.Tensor, @@ -92,6 +90,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: torch.Tensor = None, batch_index: 
Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ) -> torch.Tensor: bsz, q_len, _ = hidden_states.size() q_len = 1 # as we always run this for single token @@ -110,13 +110,12 @@ def forward( if comp_ctx_lengths is not None: attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs["CCL"] = attention_mask.shape[-1] - kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) position_ids = position_ids[torch.arange(bsz), position_ids.to(torch.int32).argmax(1)].unsqueeze(1) query_states, _ = qeff_apply_rotary_pos_emb( - query_states, torch.empty_like(query_states), cos, sin, position_ids + query_states, torch.empty_like(query_states), cos_cached, sin_cached, position_ids ) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -145,7 +144,7 @@ def forward( class QEffLlamaSwiftKVDecoderLayer(nn.Module): - def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: + def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx, sin_cached, cos_cached) -> None: super().__init__() self.hidden_size = config.hidden_size self.num_key_value_heads = config.num_key_value_heads @@ -153,6 +152,8 @@ def __init__(self, config: QEffLlamaSwiftKVConfig, layer_idx) -> None: self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.sin_cached = sin_cached + self.cos_cached = cos_cached def forward( self, @@ -162,10 +163,16 @@ def forward( comp_ctx_lengths, causal_mask, batch_index: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + if sin_cached is None: + sin_cached = self.sin_cached + if cos_cached is None: + cos_cached = self.cos_cached hidden_states, past_key_values = self.self_attn( hidden_states=hidden_states, @@ -174,6 +181,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, + sin_cached=sin_cached, + cos_cached=cos_cached, ) hidden_states = residual + hidden_states @@ -194,12 +203,18 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): self.vocab_size = config.vocab_size self.config = config + self.rotary_emb = QEffLlamaRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, None) + sin_cached, cos_cached = self.sin_cached, self.cos_cached self.layers = torch.nn.ModuleList( [ QEffLlamaDecoderLayer(config=config, layer_idx=idx) if idx < config.num_key_value_layers - else QEffLlamaSwiftKVDecoderLayer(config=config, layer_idx=idx) + else QEffLlamaSwiftKVDecoderLayer( + config=config, layer_idx=idx, sin_cached=sin_cached, cos_cached=cos_cached + ) for idx in range(config.num_hidden_layers) ] ) @@ -218,7 +233,14 @@ def _run_swiftkv_layers( for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] 
hidden_states = layer( - hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index + hidden_states, + position_ids, + past_key_values, + comp_ctx_lengths, + causal_mask, + batch_index, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) @@ -347,6 +369,8 @@ def forward( batch_index=batch_index, output_attentions=False, use_cache=True, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) bsz, q_len, _ = hidden_states.size() @@ -370,10 +394,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) + # kv_seq_len = past_key_values.get_seq_length(self_attn.layer_idx) - cos, sin = self_attn.rotary_emb(value_states, seq_len=kv_seq_len) - _, key_states = qeff_apply_rotary_pos_emb(torch.empty_like(key_states), key_states, cos, sin, position_ids) + _, key_states = qeff_apply_rotary_pos_emb( + torch.empty_like(key_states), key_states, self.cos_cached, self.sin_cached, position_ids + ) cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} past_key_values.write_only(key_states, value_states, self_attn.layer_idx, cache_kwargs) diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index 47107384ed..76a7d24c64 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -131,9 +121,6 @@ class QEffMistralAttention(MistralAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -146,6 +133,8 @@ def forward( use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: tuple[torch.Tensor, torch.Tensor] = None, # kept here for BC + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -159,9 +148,10 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -205,6 +195,8 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -236,6 +228,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -256,6 +250,11 @@ class QEffMistralModel(MistralModel): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffMistralRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -328,6 +327,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 680c839ae5..e59a3be534 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -60,16 +60,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
@@ -128,9 +118,6 @@ def eager_attention_forward( class QEffMixtralAttention(MixtralAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -139,6 +126,8 @@ def forward( past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -155,9 +144,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx) + + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -265,6 +256,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -301,6 +294,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -323,6 +318,11 @@ class QEffMixtralModel(MixtralModel): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffMixtralRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + # Ignore copy def forward( self, @@ -397,6 +397,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 3cba022b48..642fb4bb7c 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -123,16 +123,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 
"""Applies Rotary Position Embedding to the query and key tensors. @@ -241,9 +231,6 @@ class QEffMllamaTextSelfAttention(MllamaTextSelfAttention): - add new args cache idx for the kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -255,6 +242,8 @@ def forward( position_embeddings: torch.Tensor = None, use_cache: bool = False, cache_position=None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -274,10 +263,11 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = { @@ -326,6 +316,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -361,6 +353,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, ) hidden_states = residual + hidden_states @@ -465,6 +459,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -602,6 +598,11 @@ class QEffMllamaTextModel(MllamaTextModel): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffMllamaRotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -676,6 +677,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index c79ad7faee..22834d2926 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -54,16 +54,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - 
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -123,9 +113,6 @@ def eager_attention_forward( class QEffOlmo2Attention(Olmo2Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __qeff_init__(self): - self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -135,6 +122,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -148,11 +137,11 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = key_states.shape[-2] + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -198,6 +187,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -213,6 +204,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -233,6 +226,11 @@ class QEffOlmo2Model(Olmo2Model): - add new args cache idx for the kv retention """ + def __qeff_init__(self): + self.rotary_emb = QEffOlmo2RotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: torch.LongTensor = None, @@ -297,6 +295,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index b48ab28979..b18dbcd5c0 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -52,16 +52,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", 
emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -129,9 +119,6 @@ class QEffPhi3Attention(Phi3Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -142,6 +129,8 @@ def forward( past_key_value: Optional[Cache] = None, comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -157,10 +146,11 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = { @@ -207,6 +197,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -244,6 +236,8 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) @@ -265,6 +259,11 @@ class QEffPhi3Model(Phi3Model): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffPhi3RotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached) + def forward( self, input_ids: torch.LongTensor = None, @@ -324,6 +323,8 @@ def forward( comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, **kwargs, ) diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 841df65269..df7421c466 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - def forward(self, x, seq_len=None): 
- # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, - ) - def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). @@ -141,9 +131,6 @@ class QEffQwen2Attention(Qwen2Attention): - add new args position idx for the cache_kwargs for kv retention """ - def __qeff_init__(self): - self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, @@ -153,6 +140,8 @@ def forward( comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + cos_cached: Optional[torch.Tensor] = None, + sin_cached: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -162,9 +151,10 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos_cached, sin_cached, position_ids + ) if past_key_value is not None: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} @@ -208,6 +198,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + sin_cached=None, + cos_cached=None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -240,6 +232,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=sin_cached, + cos_cached=cos_cached, **kwargs, ) hidden_states = residual + hidden_states @@ -261,6 +255,11 @@ class QEffQwen2Model(Qwen2Model): - update causal attention mask """ + def __qeff_init__(self): + self.rotary_emb = QEffQwen2RotaryEmbedding(config=self.config) + self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling) + self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling) + def forward( self, input_ids: torch.LongTensor = None, @@ -324,6 +323,8 @@ def forward( batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, + sin_cached=self.sin_cached, + cos_cached=self.cos_cached, ) hidden_states = self.norm(hidden_states) diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index d6bfbda81b..f333302bc0 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -348,16 +348,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): 
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-            self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-        )
-
 
 def eager_attention_forward_blockedKV(
     module: nn.Module,
@@ -564,9 +554,6 @@ class QEffQwen2_5_VLAttention(Qwen2_5_VLAttention):
     and "Generating Long Sequences with Sparse Transformers".
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -579,6 +566,8 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         num_kv_blocks: Optional[torch.Tensor] = None,
+        cos_cached: Optional[torch.Tensor] = None,
+        sin_cached: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -591,21 +580,19 @@ def forward(
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        kv_seq_len = key_states.shape[-2]
-        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        # kv_seq_len = key_states.shape[-2]
+        # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
         past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
         query_states, key_states = qeff_apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids[1:], self.rope_scaling["mrope_section"]
+            query_states, key_states, cos_cached, sin_cached, position_ids[1:], self.rope_scaling["mrope_section"]
         )
 
         if past_key_value is not None:
             if num_kv_blocks is not None:
                 cache_kwargs = {
-                    "sin": sin,
-                    "cos": cos,
+                    "sin": sin_cached,
+                    "cos": cos_cached,
                     "batch_index": batch_index,
                     "position_ids": position_ids[0],
                     "past_seen_tokens": past_seen_tokens,
@@ -614,8 +601,8 @@ def forward(
             else:
                 # sin and cos are specific to RoPE models; cache_position needed for the static cache
                 cache_kwargs = {
-                    "sin": sin,
-                    "cos": cos,
+                    "sin": sin_cached,
+                    "cos": cos_cached,
                     "batch_index": batch_index,
                     "position_ids": position_ids[0],
                 }
@@ -661,6 +648,8 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        sin_cached=None,
+        cos_cached=None,
         # position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -701,6 +690,8 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            sin_cached=sin_cached,
+            cos_cached=cos_cached,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -723,6 +714,11 @@ def forward(
 
 
 class QEffQwen2_5_VLTextModel(Qwen2_5_VLTextModel):
+    def __qeff_init__(self):
+        self.rotary_emb = QEffQwen2_5_VLRotaryEmbedding(config=self.config)
+        self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling)
+        self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -785,6 +781,8 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            sin_cached=self.sin_cached,
+            cos_cached=self.cos_cached,
             **kwargs,
         )
 
diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py
index ccc4bbac29..4202f52e18 100644
--- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py
+++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py
@@ -58,16 +58,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-            self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-        )
-
 
 def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
@@ -142,9 +132,6 @@ class QEffQwen3Attention(Qwen3Attention):
     - add new args position idx for the cache_kwargs for kv retention
     """
 
-    def __qeff_init__(self):
-        self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -154,6 +141,8 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        cos_cached: Optional[torch.Tensor] = None,
+        sin_cached: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -163,9 +152,10 @@ def forward(
         key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        query_states, key_states = qeff_apply_rotary_pos_emb(
+            query_states, key_states, cos_cached, sin_cached, position_ids
+        )
 
         if past_key_value is not None:
             cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
@@ -209,6 +199,8 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        sin_cached=None,
+        cos_cached=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -241,6 +233,8 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            sin_cached=sin_cached,
+            cos_cached=cos_cached,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -262,6 +256,11 @@ class QEffQwen3Model(Qwen3Model):
     - update causal attention mask
     """
 
+    def __qeff_init__(self):
+        self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config)
+        self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling)
+        self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling)
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -325,6 +324,8 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            sin_cached=self.sin_cached,
+            cos_cached=self.cos_cached,
         )
 
         hidden_states = self.norm(hidden_states)
 
diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index 6bdd5e2439..fb7320ff6a 100644
--- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -50,16 +50,6 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
 
 def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
@@ -188,9 +178,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
 
 
 class QEffQwen3MoeAttention(Qwen3MoeAttention):
-    def __qeff_init__(self):
-        self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -200,6 +187,8 @@ def forward(
         comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        cos_cached: Optional[torch.Tensor] = None,
+        sin_cached: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -209,9 +198,10 @@ def forward(
         key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-        kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        query_states, key_states = qeff_apply_rotary_pos_emb(
+            query_states, key_states, cos_cached, sin_cached, position_ids
+        )
 
         if past_key_value is not None:
             cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
@@ -247,6 +237,8 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        sin_cached=None,
+        cos_cached=None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -279,6 +271,8 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            sin_cached=sin_cached,
+            cos_cached=cos_cached,
         )
 
         hidden_states = residual + hidden_states
@@ -296,6 +290,11 @@ def forward(
 
 
 class QEffQwen3MoeModel(Qwen3MoeModel):
+    def __qeff_init__(self):
+        self.rotary_emb = QEffQwen3MoeRotaryEmbedding(config=self.config)
+        self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached)
+        self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -349,6 +348,8 @@ def forward(
             batch_index=batch_index,
             use_cache=use_cache,
             cache_position=cache_position,
+            sin_cached=self.sin_cached,
+            cos_cached=self.cos_cached,
         )
 
         hidden_states = self.norm(hidden_states)
 
diff --git a/tests/transformers/models/test_single_subfunction.py b/tests/transformers/models/test_single_subfunction.py
new file mode 100644
index 0000000000..f17edab654
--- /dev/null
+++ b/tests/transformers/models/test_single_subfunction.py
@@ -0,0 +1,94 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import onnx
+import pytest
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+from QEfficient.utils.device_utils import get_available_device_id
+
+torch.manual_seed(42)
+
+configs = [
+    ("gpt2", 256, 2, 4, 128, 512, 127, {}),
+    # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+    ("falcon", 256, 2, 4, 128, 512, 127, {}),
+    ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+    ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("mpt", 256, 2, 4, 128, 512, 127, {}),
+    # ("phi", 256, 2, 4, 128, 512, 127, {}),
+    ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}),
+    ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("qwen3", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("starcoder2", 256, 2, 4, 128, 512, 127, {}),
+    ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("qwen3_moe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    # ("granitemoe", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("gemma", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("gemma2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+]
+
+configs = [
+    AutoConfig.for_model(
+        model_name,
+        max_position_embeddings=max_position_embeddings,
+        num_hidden_layers=num_hidden_layers,
+        num_attention_heads=num_attention_heads,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        vocab_size=vocab_size,
+        **additional_params,
+    )
+    for (
+        model_name,
+        max_position_embeddings,
+        num_hidden_layers,
+        num_attention_heads,
+        hidden_size,
+        intermediate_size,
+        vocab_size,
+        additional_params,
+    ) in configs
+]
+
+model_kwargs = {"attn_implementation": "eager"}
+config_ids = [x.model_type for x in configs]
+
+
+def get_function(onnx_path):
+    """Return the names of the function definitions contained in the exported ONNX model."""
+    model = onnx.load(onnx_path, load_external_data=False)
+    function_names = [f.name for f in model.functions]
+    return function_names
+
+
+@pytest.mark.on_qaic
+@pytest.mark.feature
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_subfunction_vs_nonsubfunction(config, tmp_path):
+    model_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb=False)
+    with_sub_func_onnx = model_0_0.export(tmp_path, use_onnx_subfunctions=True, offload_pt_weights=False)
+
+    functions_names = get_function(with_sub_func_onnx)
+    print(functions_names)
+
+    keywords = ["DecoderLayer", "Block", "Layer"]
+    filtered = [name for name in functions_names if any(key in name for key in keywords)]
+
+    if len(filtered) > 1:
+        raise AssertionError(f"Expected at most one decoder-layer function definition, but found {len(filtered)}: {filtered}")
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+    compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
+    model_0_0.compile(onnx_path=with_sub_func_onnx, **compile_params, use_onnx_subfunctions=True)