23 changes: 6 additions & 17 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -60,30 +60,17 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def qeff_apply_rotary_pos_emb(q, k, cos, sin):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

# Apply rotation
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -127,7 +114,7 @@ def forward(
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached, position_ids)
query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos_cached, sin_cached)

if layer_past is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
@@ -301,6 +288,8 @@ def forward(

all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
sin = self.sin_cached[position_ids].unsqueeze(1)
cos = self.cos_cached[position_ids].unsqueeze(1)

for i, block in enumerate(self.h):
if output_hidden_states:
@@ -319,8 +308,8 @@
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
sin_cached=self.sin_cached,
cos_cached=self.cos_cached,
sin_cached=sin,
cos_cached=cos,
)

hidden_states = outputs[0]
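The change is the same across all five model files: the `position_ids` gather and `unsqueeze` move out of `qeff_apply_rotary_pos_emb` and up into the model `forward`, so the helper only applies the rotation to already-broadcastable `cos`/`sin` tensors. Doing the gather once per forward call means the indexing appears a single time per call instead of once per attention layer. A minimal sketch of the equivalence, with illustrative shapes and stand-in buffers rather than the actual QEfficient modules:

```python
# Equivalence sketch for the refactor; shapes and buffer values are
# illustrative stand-ins, not the actual QEfficient caches.
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def old_apply(q, k, cos_cached, sin_cached, position_ids, unsqueeze_dim=1):
    # Old-style helper: gather + unsqueeze happened inside the function.
    cos = cos_cached[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin_cached[position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def new_apply(q, k, cos, sin):
    # New-style helper: cos/sin arrive already gathered and broadcastable.
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


batch, heads, seq, head_dim, max_pos = 2, 4, 8, 64, 32
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos_cached = torch.randn(max_pos, head_dim)           # stands in for self.cos_cached
sin_cached = torch.randn(max_pos, head_dim)           # stands in for self.sin_cached
position_ids = torch.arange(seq).expand(batch, seq)   # [batch, seq]

# Model-level gather, done once per forward in the new code
cos = cos_cached[position_ids].unsqueeze(1)            # [batch, 1, seq, head_dim]
sin = sin_cached[position_ids].unsqueeze(1)

q_old, k_old = old_apply(q, k, cos_cached, sin_cached, position_ids)
q_new, k_new = new_apply(q, k, cos, sin)
assert torch.equal(q_old, q_new) and torch.equal(k_old, k_new)
```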
26 changes: 6 additions & 20 deletions QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -56,30 +56,17 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def qeff_apply_rotary_pos_emb(q, k, cos, sin):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

# Apply rotation
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -138,10 +125,7 @@ def forward(
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos_cached, sin_cached, position_ids
)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached)

if past_key_values is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
@@ -295,6 +279,8 @@ def forward(

# decoder layers
all_hidden_states = () if output_hidden_states else None
sin = self.sin_cached[position_ids].unsqueeze(1)
cos = self.cos_cached[position_ids].unsqueeze(1)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -309,8 +295,8 @@
batch_index=batch_index,
use_cache=use_cache,
cache_position=cache_position,
sin_cached=self.sin_cached,
cos_cached=self.cos_cached,
sin_cached=sin,
cos_cached=cos,
**kwargs,
)

26 changes: 6 additions & 20 deletions QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -59,30 +59,17 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def qeff_apply_rotary_pos_emb(q, k, cos, sin):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

# Apply rotation
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -145,10 +132,7 @@ def forward(
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos_cached, sin_cached, position_ids
)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -339,6 +323,8 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
sin = self.sin_cached[position_ids].unsqueeze(1)
cos = self.cos_cached[position_ids].unsqueeze(1)

for decoder_layer in self.layers:
if output_hidden_states:
@@ -354,8 +340,8 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
sin_cached=self.sin_cached,
cos_cached=self.cos_cached,
sin_cached=sin,
cos_cached=cos,
**kwargs,
)

45 changes: 14 additions & 31 deletions QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -535,7 +535,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def qeff_apply_rotary_pos_emb(q, k, cos, sin):
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

Explanation:
@@ -552,25 +552,10 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
mrope_section(`List(int)`):
Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""

cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)

@@ -748,9 +733,7 @@ def forward(
hidden_shape = (*input_shape, -1, self.head_dim)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos_cached, sin_cached, position_ids
)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -832,9 +815,7 @@ def forward(
hidden_shape = (*input_shape, -1, self.head_dim)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos_cached, sin_cached, position_ids
)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -912,9 +893,7 @@ def forward(
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos_cached, sin_cached, position_ids
)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -1071,6 +1050,8 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
sin = self.sin_cached[position_ids].unsqueeze(1)
cos = self.cos_cached[position_ids].unsqueeze(1)

for decoder_layer in self.layers:
if output_hidden_states:
@@ -1086,8 +1067,8 @@
output_attentions=output_attentions,
cache_position=cache_position,
sliding_mask=sliding_mask,
sin_cached=self.sin_cached,
cos_cached=self.cos_cached,
sin_cached=sin,
cos_cached=cos,
**kwargs,
)
hidden_states = layer_outputs[0]
@@ -1112,8 +1093,8 @@
class QEffGptOssModel(GptOssModel):
def __qeff_init__(self):
self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config)
self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached)
self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached)
self.sin_cached = torch.nn.Parameter(self.rotary_emb.sin_cached * self.rotary_emb.attention_scaling)
self.cos_cached = torch.nn.Parameter(self.rotary_emb.cos_cached * self.rotary_emb.attention_scaling)

def forward(
self,
@@ -1171,6 +1152,8 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
sin = self.sin_cached[position_ids].unsqueeze(1)
cos = self.cos_cached[position_ids].unsqueeze(1)

for decoder_layer in self.layers:
if output_hidden_states:
Expand All @@ -1187,8 +1170,8 @@ def forward(
output_attentions=output_attentions,
cache_position=cache_position,
sliding_mask=sliding_mask,
sin_cached=self.sin_cached,
cos_cached=self.cos_cached,
sin_cached=sin,
cos_cached=cos,
**kwargs,
)
hidden_states = layer_outputs[0]
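One GPT-OSS-specific detail: `__qeff_init__` now multiplies the cached `sin`/`cos` buffers by `rotary_emb.attention_scaling` when creating the parameters, so the scaling is baked into the cache rather than applied separately. Because the scaling is a scalar multiplier, folding it into the buffer once gives the same values as multiplying after the per-position gather. A small sketch under that assumption; the frequency setup and the 0.75 scale below are illustrative, not the GPT-OSS config values:

```python
# Sketch: folding a scalar attention scaling into the RoPE cache at init
# is equivalent to applying it after the position gather.
import torch

max_pos, head_dim = 32, 64
attention_scaling = 0.75  # illustrative value only

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)               # [max_pos, head_dim]

position_ids = torch.tensor([[0, 1, 2, 5]])

# Scale applied after the gather...
scaled_at_use = emb.cos()[position_ids] * attention_scaling

# ...vs. scale folded into the cached buffer once, as in __qeff_init__
cos_cached = emb.cos() * attention_scaling
folded = cos_cached[position_ids]

assert torch.allclose(scaled_at_use, folded)
```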
26 changes: 6 additions & 20 deletions QEfficient/transformers/models/granite/modeling_granite.py
@@ -54,30 +54,17 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def qeff_apply_rotary_pos_emb(q, k, cos, sin):
"""Applies Rotary Position Embedding to the query and key tensors.

Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)

# Apply rotation
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
@@ -131,10 +118,7 @@ def forward(
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

# kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos_cached, sin_cached, position_ids
)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos_cached, sin_cached)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -300,6 +284,8 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
sin = self.sin_cached[position_ids].unsqueeze(1)
cos = self.cos_cached[position_ids].unsqueeze(1)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
@@ -315,8 +301,8 @@
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
sin_cached=self.sin_cached,
cos_cached=self.cos_cached,
sin_cached=sin,
cos_cached=cos,
**kwargs,
)

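Taken together, each model's `forward` now follows the pattern sketched below: one gather and `unsqueeze` on the cached buffers per call, with the resulting tensors handed unchanged to every decoder layer. This is a toy illustration of the calling convention, not the actual QEfficient classes:

```python
# Toy illustration of the model-level pattern: gather sin/cos once per
# forward and pass the same tensors to every layer. Names are made up.
import torch
import torch.nn as nn


class TinyLayer(nn.Module):
    def forward(self, hidden_states, cos, sin):
        # a real decoder layer would rotate q/k with cos/sin here
        return hidden_states


class TinyModel(nn.Module):
    def __init__(self, num_layers=2, max_pos=32, head_dim=64):
        super().__init__()
        self.layers = nn.ModuleList(TinyLayer() for _ in range(num_layers))
        emb = torch.randn(max_pos, head_dim)  # stands in for the RoPE angle table
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, hidden_states, position_ids):
        # one gather + unsqueeze for the whole decoder stack
        cos = self.cos_cached[position_ids].unsqueeze(1)
        sin = self.sin_cached[position_ids].unsqueeze(1)
        for layer in self.layers:
            hidden_states = layer(hidden_states, cos=cos, sin=sin)
        return hidden_states


model = TinyModel()
out = model(torch.randn(1, 8, 16), torch.arange(8).unsqueeze(0))
```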