diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py
index 5be21e2f9a51..8a5c237612a6 100644
--- a/src/transformers/integrations/hub_kernels.py
+++ b/src/transformers/integrations/hub_kernels.py
@@ -25,6 +25,7 @@
         LayerRepository,
         Mode,
         get_kernel,
+        load_kernel,
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
         use_kernel_forward_from_hub,
@@ -115,6 +116,9 @@
 
     register_kernel_mapping(_KERNEL_MAPPING)
 
+    # Preload the rotary kernel as it's used in many models.
+    rotary_kernel = get_kernel(repo_id="kernels-community/rotary")
+
 except ImportError:
     _kernels_available = False
 
@@ -138,6 +142,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
     def register_kernel_mapping(*args, **kwargs):
         raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
 
+    rotary_kernel = None
+
 
 def is_kernel(attn_implementation: Optional[str]) -> bool:
     """Check whether `attn_implementation` matches a kernel pattern from the hub."""
@@ -201,4 +207,5 @@ def load_and_register_kernel(attn_implementation: str) -> None:
     "use_kernel_forward_from_hub",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
+    "rotary_kernel",
 ]
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 81b16c4ee6b6..4174989970ff 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -28,6 +28,7 @@
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...integrations import use_kernel_forward_from_hub
+from ...integrations.hub_kernels import rotary_kernel
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import (
@@ -117,6 +118,34 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
+def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """
+    Rotary kernel implementation wrapper
+    Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    q_rotated = q.clone()
+    k_rotated = k.clone()
+
+    # Get half dimension for rotation
+    half_dim = q.shape[-1] // 2
+    q1 = q_rotated[..., :half_dim]
+    q2 = q_rotated[..., half_dim:]
+    k1 = k_rotated[..., :half_dim]
+    k2 = k_rotated[..., half_dim:]
+    if cos.shape[-1] != half_dim:
+        # Trim cos/sin to match half_dim
+        cos = cos[..., :half_dim]
+        sin = sin[..., :half_dim]
+
+    # Apply rotary embedding using our kernel
+    rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+    rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+    return q_rotated, k_rotated
+
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -192,6 +221,7 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        use_kernels: Optional[bool] = False,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
@@ -202,7 +232,10 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        if use_kernels and rotary_kernel:
+            query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -252,6 +285,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        use_kernels: Optional[bool] = False,
         **kwargs: Unpack[TransformersKwargs],
     ) -> torch.Tensor:
         residual = hidden_states
@@ -265,6 +299,7 @@ def forward(
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            use_kernels=use_kernels,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -362,6 +397,7 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        use_kernels: Optional[bool] = False,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -415,6 +451,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
+                use_kernels=use_kernels,
                 **kwargs,
             )
 
@@ -485,6 +522,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
+            use_kernels=self.use_kernels,
             **kwargs,
         )
 
diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py
index f1e38841faf4..ef764b0e90b0 100644
--- a/src/transformers/models/qwen3/modular_qwen3.py
+++ b/src/transformers/models/qwen3/modular_qwen3.py
@@ -19,6 +19,7 @@
 import torch
 
 from ...cache_utils import Cache
+from ...integrations.hub_kernels import rotary_kernel
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import CausalLMOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
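
Note, not part of the patch: a minimal parity sketch comparing the new apply_rotary_kernel wrapper against the eager apply_rotary_pos_emb path. It assumes a CUDA device with the kernels package installed, fp16 tensors in the Qwen3 attention layout (batch, heads, seq, head_dim), and cos/sin tensors with the half-dim values duplicated along the last axis, as Qwen3RotaryEmbedding produces them; the shapes and tolerances below are illustrative choices, not values taken from this PR.

    import torch
    from transformers.models.qwen3.modeling_qwen3 import apply_rotary_kernel, apply_rotary_pos_emb

    batch, heads, seq, head_dim = 1, 8, 16, 128
    q = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)

    # cos/sin mimic Qwen3RotaryEmbedding output: shape (batch, seq, head_dim), two identical halves.
    angles = torch.rand(batch, seq, head_dim // 2, device="cuda", dtype=torch.float32)
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1).to(q.dtype)
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1).to(q.dtype)

    q_ref, k_ref = apply_rotary_pos_emb(q, k, cos, sin)  # eager reference path
    q_ker, k_ker = apply_rotary_kernel(q, k, cos, sin)   # hub kernel path added in this PR
    torch.testing.assert_close(q_ref, q_ker, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(k_ref, k_ker, rtol=2e-2, atol=2e-2)

At the model level the same path is exercised through the use_kernels flag threaded into Qwen3Model.forward; Qwen3ForCausalLM forwards self.use_kernels, which is set outside this diff.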