From cae1569b42b8573f85a23b9f642c25a2d5de56ec Mon Sep 17 00:00:00 2001 From: "Hu, Yabai" Date: Thu, 25 Sep 2025 02:00:00 +0000 Subject: [PATCH 1/7] add rotary kernel support for Qwen3 model Signed-off-by: Hu, Yabai --- src/transformers/integrations/hub_kernels.py | 7 ++++ .../models/qwen3/modeling_qwen3.py | 34 ++++++++++++++++++- .../models/qwen3/modular_qwen3.py | 1 + 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 5be21e2f9a51..62c9c14bfcc8 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -25,6 +25,7 @@ LayerRepository, Mode, get_kernel, + load_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, @@ -115,6 +116,9 @@ register_kernel_mapping(_KERNEL_MAPPING) + # Preload the rotary kernel as it's used in many models. + rotary_kernel = load_kernel(repo_id="kernels-community/rotary") + except ImportError: _kernels_available = False @@ -138,6 +142,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") + rotary_kernel = None + def is_kernel(attn_implementation: Optional[str]) -> bool: """Check whether `attn_implementation` matches a kernel pattern from the hub.""" @@ -201,4 +207,5 @@ def load_and_register_kernel(attn_implementation: str) -> None: "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", + "rotary_kernel", ] diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 81b16c4ee6b6..f7df7b9a6b42 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import rotary_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -117,6 +118,34 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -202,7 +231,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if rotary_kernel: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index f1e38841faf4..ef764b0e90b0 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -19,6 +19,7 @@ import torch from ...cache_utils import Cache +from ...integrations.hub_kernels import rotary_kernel from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS From 3dae0eec3f2812e8efaa8f8e3216050d50c6a6cb Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 04:40:42 +0000 Subject: [PATCH 2/7] use get_kernel Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 3 +-- src/transformers/models/qwen3/modeling_qwen3.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 62c9c14bfcc8..05f580771888 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -25,7 +25,6 @@ LayerRepository, Mode, get_kernel, - load_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, @@ -117,7 +116,7 @@ register_kernel_mapping(_KERNEL_MAPPING) # Preload the rotary kernel as it's used in many models. - rotary_kernel = load_kernel(repo_id="kernels-community/rotary") + rotary_kernel = get_kernel(repo_id="kernels-community/rotary") except ImportError: _kernels_available = False diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index f7df7b9a6b42..d06c76c93a44 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -125,10 +125,10 @@ def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) - + q_rotated = q.clone() k_rotated = k.clone() - + # Get half dimension for rotation half_dim = q.shape[-1] // 2 q1 = q_rotated[..., :half_dim] @@ -221,6 +221,7 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, + use_kernels: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] @@ -231,7 +232,7 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - if rotary_kernel: + if use_kernels and rotary_kernel: query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) else: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -284,6 +285,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -297,6 +299,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + use_kernels=use_kernels, **kwargs, ) hidden_states = residual + hidden_states @@ -394,6 +397,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): @@ -447,6 +451,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + use_kernels=use_kernels, **kwargs, ) @@ -517,6 +522,7 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, + use_kernels=self.use_kernels, **kwargs, ) From bdefedca3db71a89982f5f6380e248ffadb6c751 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 04:43:55 +0000 Subject: [PATCH 3/7] add rotary kernel support to Qwen3 model Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 6 +++ .../models/qwen3/modeling_qwen3.py | 40 ++++++++++++++++++- .../models/qwen3/modular_qwen3.py | 1 + 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 5be21e2f9a51..05f580771888 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -115,6 +115,9 @@ register_kernel_mapping(_KERNEL_MAPPING) + # Preload the rotary kernel as it's used in many models. + rotary_kernel = get_kernel(repo_id="kernels-community/rotary") + except ImportError: _kernels_available = False @@ -138,6 +141,8 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") + rotary_kernel = None + def is_kernel(attn_implementation: Optional[str]) -> bool: """Check whether `attn_implementation` matches a kernel pattern from the hub.""" @@ -201,4 +206,5 @@ def load_and_register_kernel(attn_implementation: str) -> None: "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", + "rotary_kernel", ] diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 81b16c4ee6b6..d06c76c93a44 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub +from ...integrations.hub_kernels import rotary_kernel from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import ( @@ -117,6 +118,34 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Rotary kernel implementation wrapper + Adapts rotary kernels implementation to match HuggingFace apply_rotary_pos_emb signature + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + q_rotated = q.clone() + k_rotated = k.clone() + + # Get half dimension for rotation + half_dim = q.shape[-1] // 2 + q1 = q_rotated[..., :half_dim] + q2 = q_rotated[..., half_dim:] + k1 = k_rotated[..., :half_dim] + k2 = k_rotated[..., half_dim:] + if cos.shape[-1] != half_dim: + # Trim cos/sin to match half_dim + cos = cos[..., :half_dim] + sin = sin[..., :half_dim] + + # Apply rotary embedding using our kernel + rotary_kernel.apply_rotary(q1, q2, cos, sin, q1, q2, False) + rotary_kernel.apply_rotary(k1, k2, cos, sin, k1, k2, False) + return q_rotated, k_rotated + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -192,6 +221,7 @@ def forward( attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, + use_kernels: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: input_shape = hidden_states.shape[:-1] @@ -202,7 +232,10 @@ def forward( value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if use_kernels and rotary_kernel: + query_states, key_states = apply_rotary_kernel(query_states, key_states, cos, sin, cache_position) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -252,6 +285,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states @@ -265,6 +299,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + use_kernels=use_kernels, **kwargs, ) hidden_states = residual + hidden_states @@ -362,6 +397,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + use_kernels: Optional[bool] = False, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): @@ -415,6 +451,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + use_kernels=use_kernels, **kwargs, ) @@ -485,6 +522,7 @@ def forward( inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, + use_kernels=self.use_kernels, **kwargs, ) diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index f1e38841faf4..ef764b0e90b0 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -19,6 +19,7 @@ import torch from ...cache_utils import Cache +from ...integrations.hub_kernels import rotary_kernel from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS From ceac675c4341673d19eea5e1211f1f4c2d6fa6d6 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 04:52:39 +0000 Subject: [PATCH 4/7] add rotary kernel support for Qwen3 model Signed-off-by: Hu, Yabai Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 + src/transformers/models/qwen3/modeling_qwen3.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 05f580771888..8a5c237612a6 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -25,6 +25,7 @@ LayerRepository, Mode, get_kernel, + load_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index d06c76c93a44..4174989970ff 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -128,7 +128,7 @@ def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): q_rotated = q.clone() k_rotated = k.clone() - + # Get half dimension for rotation half_dim = q.shape[-1] // 2 q1 = q_rotated[..., :half_dim] From eee8c53cc625e5c83884a4924ce179223ecf31a7 Mon Sep 17 00:00:00 2001 From: "Hu, Yabai" Date: Thu, 25 Sep 2025 02:00:00 +0000 Subject: [PATCH 5/7] add rotary kernel support for Qwen3 model Signed-off-by: Hu, Yabai --- src/transformers/integrations/hub_kernels.py | 1 + src/transformers/models/qwen3/modeling_qwen3.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 05f580771888..8a5c237612a6 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -25,6 +25,7 @@ LayerRepository, Mode, get_kernel, + load_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index d06c76c93a44..4174989970ff 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -128,7 +128,7 @@ def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): q_rotated = q.clone() k_rotated = k.clone() - + # Get half dimension for rotation half_dim = q.shape[-1] // 2 q1 = q_rotated[..., :half_dim] From 671764ac7be984de6fed64ed333613b281be207f Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 04:40:42 +0000 Subject: [PATCH 6/7] use get_kernel Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 - src/transformers/models/qwen3/modeling_qwen3.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 8a5c237612a6..05f580771888 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -25,7 +25,6 @@ LayerRepository, Mode, get_kernel, - load_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 4174989970ff..d06c76c93a44 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -128,7 +128,7 @@ def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): q_rotated = q.clone() k_rotated = k.clone() - + # Get half dimension for rotation half_dim = q.shape[-1] // 2 q1 = q_rotated[..., :half_dim] From f991bf7f6a8a7944fd28d1a2b68ca7f2ebd11a9a Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 25 Sep 2025 04:52:39 +0000 Subject: [PATCH 7/7] add rotary kernel support for Qwen3 model Signed-off-by: Liu, Kaixuan --- src/transformers/integrations/hub_kernels.py | 1 + src/transformers/models/qwen3/modeling_qwen3.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 05f580771888..8a5c237612a6 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -25,6 +25,7 @@ LayerRepository, Mode, get_kernel, + load_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index d06c76c93a44..4174989970ff 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -128,7 +128,7 @@ def apply_rotary_kernel(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): q_rotated = q.clone() k_rotated = k.clone() - + # Get half dimension for rotation half_dim = q.shape[-1] // 2 q1 = q_rotated[..., :half_dim]