7 changes: 7 additions & 0 deletions src/transformers/integrations/hub_kernels.py
@@ -25,6 +25,7 @@
LayerRepository,
Mode,
get_kernel,
load_kernel,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
@@ -115,6 +116,9 @@

register_kernel_mapping(_KERNEL_MAPPING)

# Preload the rotary kernel as it's used in many models.
rotary_kernel = get_kernel(repo_id="kernels-community/rotary")

except ImportError:
_kernels_available = False

@@ -138,6 +142,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
def register_kernel_mapping(*args, **kwargs):
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

rotary_kernel = None


def is_kernel(attn_implementation: Optional[str]) -> bool:
"""Check whether `attn_implementation` matches a kernel pattern from the hub."""
@@ -201,4 +207,5 @@ def load_and_register_kernel(attn_implementation: str) -> None:
"use_kernel_forward_from_hub",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"rotary_kernel",
]
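
For context, `get_kernel` resolves the pre-built kernel from the Hugging Face Hub (downloading and caching it on first use), and the `except ImportError` branch leaves `rotary_kernel` as `None` so downstream code can fall back to the pure-PyTorch path. A minimal consumer-side sketch of that guard, assuming only the import shown below:

from transformers.integrations.hub_kernels import rotary_kernel

def rope_backend() -> str:
    # The preloaded kernel is None when the `kernels` package (or the hub
    # kernel itself) is unavailable, so callers check before dispatching.
    if rotary_kernel is not None:
        return "kernels-community/rotary"
    return "eager (pure PyTorch rotate_half)"
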
40 changes: 39 additions & 1 deletion src/transformers/models/qwen3/modeling_qwen3.py
@@ -28,6 +28,7 @@
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...integrations.hub_kernels import rotary_kernel
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import (
@@ -117,6 +118,34 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed, k_embed


def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
Rotary kernel implementation wrapper
Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)

q_rotated = q.clone()
k_rotated = k.clone()

# Get half dimension for rotation
half_dim = q.shape[-1] // 2
q1 = q_rotated[..., :half_dim]
q2 = q_rotated[..., half_dim:]
k1 = k_rotated[..., :half_dim]
k2 = k_rotated[..., half_dim:]
if cos.shape[-1] != half_dim:
# Trim cos/sin to match half_dim
cos = cos[..., :half_dim]
sin = sin[..., :half_dim]

# Apply rotary embedding using our kernel
rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return q_rotated, k_rotated


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -192,6 +221,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
use_kernels: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
@@ -202,7 +232,10 @@
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

cos, sin = position_embeddings
-query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+if use_kernels and rotary_kernel:
+    query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+else:
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -252,6 +285,7 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
use_kernels: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
residual = hidden_states
@@ -265,6 +299,7 @@
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
use_kernels=use_kernels,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -362,6 +397,7 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
use_kernels: Optional[bool] = False,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -415,6 +451,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
use_kernels=use_kernels,
**kwargs,
)

@@ -485,6 +522,7 @@ def forward(
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
use_kernels=self.use_kernels,
**kwargs,
)

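Since the kernel path is gated behind `use_kernels`, a quick sanity check is to compare the wrapper against the eager `apply_rotary_pos_emb` on random inputs. This is a minimal sketch, assuming a CUDA device and an installed `kernels` package, with shapes following the (batch, heads, seq, head_dim) layout used above:

import torch
from transformers.models.qwen3.modeling_qwen3 import (
    apply_rotary_kernel,
    apply_rotary_pos_emb,
)

batch, heads, seq, head_dim = 2, 4, 16, 64
device = "cuda"
q = torch.randn(batch, heads, seq, head_dim, device=device)
k = torch.randn(batch, heads, seq, head_dim, device=device)

# Mimic the rotary embedding's cos/sin: half-dim angles duplicated along the
# last axis, giving shape (batch, seq, head_dim).
angles = torch.rand(batch, seq, head_dim // 2, device=device)
emb = torch.cat((angles, angles), dim=-1)
cos, sin = emb.cos(), emb.sin()

q_ref, k_ref = apply_rotary_pos_emb(q, k, cos, sin)
q_ker, k_ker = apply_rotary_kernel(q, k, cos, sin)
torch.testing.assert_close(q_ref, q_ker, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(k_ref, k_ker, rtol=1e-3, atol=1e-3)
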
1 change: 1 addition & 0 deletions src/transformers/models/qwen3/modular_qwen3.py
@@ -19,6 +19,7 @@
import torch

from ...cache_utils import Cache
from ...integrations.hub_kernels import rotary_kernel
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
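
End to end, the flag threads from the causal-LM wrapper (`use_kernels=self.use_kernels`) down to each attention layer; where `self.use_kernels` is initialized is not shown in this diff. A hypothetical sketch that toggles the attribute directly on a loaded model, assuming it is the only switch needed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen3-0.6B"  # any Qwen3 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Hypothetical toggle: the forward pass above reads `self.use_kernels`, so
# setting it routes RoPE through the hub kernel when that kernel loaded.
model.use_kernels = True

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))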